mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,104 +1,119 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
4
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams
6
- from msprobe.pytorch.free_benchmark.common.utils import TorchC
7
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
8
- NpuBaseLayer,
9
- )
10
-
11
-
12
class BitNoiseLayer(NpuBaseLayer):
    """Perturbation layer that XORs a small integer into the bits of float inputs.

    An eligible tensor is reinterpreted (via ``Tensor.view``) as the integer
    dtype mapped from its float dtype in ``ThresholdConfig.PERTURBATION_BIT_DICT``
    and XOR-ed element-wise with ``bit_tail``. Only elements whose absolute
    value is greater than ``torch.finfo(dtype).smallest_normal`` are changed.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Bitwise operation used to inject the noise.
        self.bit_mode = TorchC.bitwise_xor
        # Integer value combined with each selected element via ``bit_mode``.
        self.bit_tail: int = 1
        # Integer dtype used to reinterpret the tensor currently being
        # processed; None means that tensor's dtype is unsupported.
        self.bit_type = None

    def add_bit_noise(self, tensor_obj):
        """Add bit noise to the input.

        Tensors are perturbed when ``pre_check`` (inherited from NpuBaseLayer,
        presumably delegating to ``_check_details`` -- confirm) accepts them;
        otherwise they are returned unchanged. Dicts, tuples and lists are
        rebuilt with every element processed recursively. Any other object is
        returned as-is.
        """
        # Translated original note: "finfo should be put on the blacklist".
        if isinstance(tensor_obj, torch.Tensor):
            self._set_perturbation_bit(tensor_obj)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
            noise = TorchC.full(
                tensor_obj.shape,
                self.bit_tail,
                device=tensor_obj.device,
                dtype=self.bit_type,
            )
            # Reinterpret the float bits as integers so they can be XOR-ed.
            result = tensor_obj.view(self.bit_type)
            # Leave zeros/subnormals untouched; flip bits only where
            # |x| > smallest normal, then reinterpret back to the float dtype.
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
                self.bit_mode(result, noise),
                result,
            ).view(tensor_obj.dtype)

            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """Perturb the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is noised and stored in
        ``params.perturbed_value``; the return value comes from the inherited
        ``perturbed_result`` helper.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether a bit-flip perturbation should be applied.

        Returns False when the dtype has no integer mapping, the tensor is
        empty, or its maximum absolute value is below the dtype's tolerance
        from ``ThresholdConfig.ABS_TOL_VALUE_DICT``. Otherwise returns True.
        """
        if not self.bit_type:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Some dtypes may not support abs/max directly; retry in float32.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True

    def _set_perturbation_bit(self, tensor_obj):
        """Select the integer reinterpretation dtype for ``tensor_obj``'s float dtype.

        ``bit_type`` is always overwritten, and set to None when the dtype is
        not in ``PERTURBATION_BIT_DICT``. This prevents a dtype chosen for a
        previously processed tensor from being reused for an unsupported one,
        which would bypass the "dtype unsupported" check.
        """
        self.bit_tail = 1
        self.bit_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.pytorch.free_benchmark import logger
18
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
19
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
20
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
21
+ from msprobe.pytorch.free_benchmark.common.utils import TorchC
22
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
23
+ NpuBaseLayer,
24
+ )
25
+
26
+
27
class BitNoiseLayer(NpuBaseLayer):
    """Perturbation layer that XORs a small integer into the bits of float inputs.

    An eligible tensor is reinterpreted (via ``Tensor.view``) as the integer
    dtype mapped from its float dtype in ``ThresholdConfig.PERTURBATION_BIT_DICT``
    and XOR-ed element-wise with ``bit_tail``. Only elements whose absolute
    value is greater than ``torch.finfo(dtype).smallest_normal`` are changed.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Bitwise operation used to inject the noise.
        self.bit_mode = TorchC.bitwise_xor
        # Integer value combined with each selected element via ``bit_mode``.
        self.bit_tail: int = 1
        # Integer dtype used to reinterpret the tensor currently being
        # processed; None means that tensor's dtype is unsupported.
        self.bit_type = None

    def add_bit_noise(self, tensor_obj):
        """Add bit noise to the input.

        Tensors are perturbed when ``pre_check`` (inherited from NpuBaseLayer,
        presumably delegating to ``_check_details`` -- confirm) accepts them;
        otherwise they are returned unchanged. Dicts, tuples and lists are
        rebuilt with every element processed recursively. Any other object is
        returned as-is.
        """
        # Translated original note: "finfo should be put on the blacklist".
        if isinstance(tensor_obj, torch.Tensor):
            self._set_perturbation_bit(tensor_obj)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
            noise = TorchC.full(
                tensor_obj.shape,
                self.bit_tail,
                device=tensor_obj.device,
                dtype=self.bit_type,
            )
            # Reinterpret the float bits as integers so they can be XOR-ed.
            result = tensor_obj.view(self.bit_type)
            # Leave zeros/subnormals untouched; flip bits only where
            # |x| > smallest normal, then reinterpret back to the float dtype.
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
                self.bit_mode(result, noise),
                result,
            ).view(tensor_obj.dtype)

            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """Perturb the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is noised and stored in
        ``params.perturbed_value``; the return value comes from the inherited
        ``perturbed_result`` helper.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether a bit-flip perturbation should be applied.

        Returns False when the dtype has no integer mapping, the tensor is
        empty, or its maximum absolute value is below the dtype's tolerance
        from ``ThresholdConfig.ABS_TOL_VALUE_DICT``. Otherwise returns True.
        """
        if not self.bit_type:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Some dtypes may not support abs/max directly; retry in float32.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True

    def _set_perturbation_bit(self, tensor_obj):
        """Select the integer reinterpretation dtype for ``tensor_obj``'s float dtype.

        ``bit_type`` is always overwritten, and set to None when the dtype is
        not in ``PERTURBATION_BIT_DICT``. This prevents a dtype chosen for a
        previously processed tensor from being reused for an unsupported one,
        which would bypass the "dtype unsupported" check.
        """
        self.bit_tail = 1
        self.bit_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
@@ -1,63 +1,87 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
- from msprobe.pytorch.free_benchmark.common.params import DataParams
5
- from msprobe.pytorch.free_benchmark.common.utils import TorchC
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
7
- NpuBaseLayer,
8
- )
9
-
10
-
11
class ChangeValueLayer(NpuBaseLayer):
    """Perturbation layer that swaps the first and last values of an input tensor.

    For 1-D tensors, elements ``[0]`` and ``[-1]`` are swapped. For tensors
    with more dimensions, the slices ``[0][0]`` and ``[-1][-1]`` are swapped.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Index of the "first" value in each indexed dimension.
        self.head: int = 0
        # Index of the "last" value in each indexed dimension.
        self.tail: int = -1

    def change_value(self, tensor_obj):
        """Swap the head and tail values of the input.

        Tensors that ``pre_check`` (inherited, presumably delegating to
        ``_check_details`` -- confirm) accepts are cloned and modified; the
        original tensor is not mutated. Dicts, tuples and lists are rebuilt
        recursively. Other objects, and rejected tensors, are returned as-is.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """Perturb the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is perturbed and stored
        in ``params.perturbed_value``; the return value comes from the
        inherited ``perturbed_result`` helper.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether the head/tail swap can be applied to ``tensor_obj``.

        Rejects 0-dim tensors and tensors whose dimension 0 has fewer than two
        entries, since ``size(0)`` would raise or the swap would be a no-op.
        Also rejects tensors with ndim > 1 and an empty dimension 1, because
        ``change_value`` indexes ``[0][0]`` in that case.
        """
        if tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 0 must greater than 1. Cancel change value."
            )
            return False
        if tensor_obj.ndim > 1 and tensor_obj.size(1) == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 1 must greater than 0. Cancel change value."
            )
            return False
        return True
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.pytorch.free_benchmark import logger
18
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
19
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
20
+ from msprobe.pytorch.free_benchmark.common.utils import TorchC
21
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
22
+ NpuBaseLayer,
23
+ )
24
+
25
+
26
class ChangeValueLayer(NpuBaseLayer):
    """Perturbation layer that swaps the first and last values of an input tensor.

    For 1-D tensors, elements ``[0]`` and ``[-1]`` are swapped. For tensors
    with more dimensions, the slices ``[0][0]`` and ``[-1][-1]`` are swapped.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Index of the "first" value in each indexed dimension.
        self.head: int = 0
        # Index of the "last" value in each indexed dimension.
        self.tail: int = -1

    def change_value(self, tensor_obj):
        """Swap the head and tail values of the input.

        Tensors that ``pre_check`` (inherited, presumably delegating to
        ``_check_details`` -- confirm) accepts are cloned and modified; the
        original tensor is not mutated. Dicts, tuples and lists are rebuilt
        recursively. Other objects, and rejected tensors, are returned as-is.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """Perturb the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is perturbed and stored
        in ``params.perturbed_value``; the return value comes from the
        inherited ``perturbed_result`` helper.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether the head/tail swap can be applied to ``tensor_obj``.

        For ndim > 1, ``change_value`` indexes ``[0][0]`` and ``[-1][-1]``.
        Dimensions 0 and 1 must therefore both be non-empty (otherwise the
        indexing raises IndexError). At least one of them must also have
        length >= 2; otherwise both indices address the same slice and the
        swap is a no-op. 0-dim tensors are rejected, and 1-D tensors need at
        least two elements.
        """
        if tensor_obj.ndim > 1:
            if (
                tensor_obj.size(0) == 0
                or tensor_obj.size(1) == 0
                or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2)
            ):
                logger.info_on_rank_0(
                    f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
                    f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
                )
                return False
        # dim() == 0 is tested first so size(0) is never called on a scalar tensor.
        elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"0-dimension must greater than 1. Cancel change value."
            )
            return False
        return True
@@ -1,68 +1,83 @@
1
- import torch
2
- from msprobe.core.common.const import Const
3
- from msprobe.pytorch.free_benchmark import logger
4
- from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
6
- from msprobe.pytorch.free_benchmark.common.params import DataParams
7
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
8
- NpuBaseLayer,
9
- )
10
-
11
-
12
class ImprovePrecisionLayer(NpuBaseLayer):
    """
    Perturbation layer that re-runs an API with its low-precision floating
    inputs promoted to a higher precision.
    """

    def improve_tensor_precision(self, tensor_obj):
        """
        Recursively promote low-precision floating-point tensors in ``tensor_obj``.

        A floating-point tensor whose dtype is not float32/float64 is cast via
        ``_change_dtype``, and ``self.is_added`` is set to True. Dicts, tuples
        and lists are rebuilt with every element processed recursively. Any
        other object is returned unchanged.

        Args:
            tensor_obj: a tensor, or an arbitrarily nested dict/tuple/list.

        Returns:
            The same structure with eligible tensors replaced by promoted copies.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_values(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Re-execute ``params.origin_func`` with higher-precision inputs.

        Positional arguments are always promoted. Keyword arguments are
        promoted too, except in the backward stage, where they are replaced by
        an empty dict. If nothing was promoted, the existing
        ``params.perturbed_result`` is returned without a second call.

        Returns:
            ``params.perturbed_result``, updated when a second execution happened.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # All inputs are already high precision: skip the second execution to
        # avoid holding extra device-memory references.
        if not self.is_added:
            return params.perturbed_result
        # Force an out-of-place call when the API accepts an ``inplace`` keyword.
        # new_kwargs is a fresh dict, so params.kwargs is not mutated here.
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_values(self, inputs):
        """
        Select the target dtype for ``inputs``: float32 for float16/bfloat16.

        For any other dtype, ``self.perturbed_value`` keeps whatever it held
        before. NOTE(review): confirm this is intended for other low-precision
        floats (e.g. float8 types).
        """
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """
        Return ``inputs`` cast to ``self.perturbed_value`` on its original device.
        """
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # ``inputs.device`` is a torch.device, so an identity check against
            # the constant can never match. Compare the device-type string by
            # value instead (assumes CommonField.META is the "meta" type string).
            if device.type == CommonField.META:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.pytorch.free_benchmark import logger
19
+ from msprobe.pytorch.free_benchmark.common.constant import CommonField
20
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
21
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
22
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
23
+ NpuBaseLayer,
24
+ )
25
+
26
+
27
class ImprovePrecisionLayer(NpuBaseLayer):
    """
    Perturbation layer that re-runs an API with its low-precision floating
    inputs promoted to a higher precision.
    """

    def improve_tensor_precision(self, tensor_obj):
        """
        Recursively promote low-precision floating-point tensors in ``tensor_obj``.

        A floating-point tensor whose dtype is not float32/float64 is cast via
        ``_change_dtype``, and ``self.is_added`` is set to True. Dicts, tuples
        and lists are rebuilt with every element processed recursively. Any
        other object is returned unchanged.

        Args:
            tensor_obj: a tensor, or an arbitrarily nested dict/tuple/list.

        Returns:
            The same structure with eligible tensors replaced by promoted copies.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_values(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Re-execute ``params.origin_func`` with higher-precision inputs.

        Positional arguments are always promoted. Keyword arguments are
        promoted too, except in the backward stage, where they are replaced by
        an empty dict. If nothing was promoted, the existing
        ``params.perturbed_result`` is returned without a second call.

        Returns:
            ``params.perturbed_result``, updated when a second execution happened.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # All inputs are already high precision: skip the second execution to
        # avoid holding extra device-memory references.
        if not self.is_added:
            return params.perturbed_result
        # Force an out-of-place call when the API accepts an ``inplace`` keyword.
        # new_kwargs is a fresh dict, so params.kwargs is not mutated here.
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_values(self, inputs):
        """
        Select the target dtype for ``inputs``: float32 for float16/bfloat16.

        For any other dtype, ``self.perturbed_value`` keeps whatever it held
        before. NOTE(review): confirm this is intended for other low-precision
        floats (e.g. float8 types).
        """
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """
        Return ``inputs`` cast to ``self.perturbed_value`` on its original device.
        """
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # ``inputs.device`` is a torch.device, so an identity check against
            # the constant can never match. Compare the device-type string by
            # value instead (assumes CommonField.META is the "meta" type string).
            if device.type == CommonField.META:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
@@ -1,28 +1,43 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
- from msprobe.pytorch.free_benchmark.common.params import DataParams
5
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
6
- NpuBaseLayer,
7
- )
8
-
9
-
10
class NoChangeLayer(NpuBaseLayer):
    """
    Perturbation layer that leaves the input untouched and simply re-executes the API.
    """

    def no_change(self, tensor_obj):
        """
        Mark the perturbation as applied and hand back ``tensor_obj`` unmodified.
        """
        untouched = tensor_obj
        self.is_added = True
        return untouched

    def handle(self, params: DataParams):
        """
        Re-run the API with its selected input unchanged.

        Logs the perturbation mode and stores the input at
        ``params.valid_input_index`` (unmodified) on ``params.perturbed_value``.
        It then returns ``self.perturbed_result(params)``.
        """
        log_text = (
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
        )
        logger.info_on_rank_0(log_text)
        chosen_input = params.args[params.valid_input_index]
        params.perturbed_value = self.no_change(chosen_input)
        return self.perturbed_result(params)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.pytorch.free_benchmark import logger
18
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
19
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
20
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
21
+ NpuBaseLayer,
22
+ )
23
+
24
+
25
class NoChangeLayer(NpuBaseLayer):
    """
    Perturbation layer that leaves the input untouched and simply re-executes the API.
    """

    def no_change(self, tensor_obj):
        """
        Mark the perturbation as applied and hand back ``tensor_obj`` unmodified.
        """
        untouched = tensor_obj
        self.is_added = True
        return untouched

    def handle(self, params: DataParams):
        """
        Re-run the API with its selected input unchanged.

        Logs the perturbation mode and stores the input at
        ``params.valid_input_index`` (unmodified) on ``params.perturbed_value``.
        It then returns ``self.perturbed_result(params)``.
        """
        log_text = (
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
        )
        logger.info_on_rank_0(log_text)
        chosen_input = params.args[params.valid_input_index]
        params.perturbed_value = self.no_change(chosen_input)
        return self.perturbed_result(params)