mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,90 +1,90 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
7
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
-
NpuBaseLayer,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class AddNoiseLayer(NpuBaseLayer):
    """
    Perturbation layer that adds a constant, dtype-dependent noise to inputs.

    The noise value is looked up in ``ThresholdConfig.PERTURBATION_VALUE_DICT``
    by tensor dtype. Only elements whose absolute value is greater than
    ``perturbed_value ** 0.5`` receive the noise; all others keep their value.
    """

    def add_noise(self, tensor_obj):
        """
        Recursively add noise to every tensor inside ``tensor_obj``.

        A tensor is perturbed when ``pre_check`` accepts it, otherwise it is
        returned unchanged. Dicts, tuples and lists are rebuilt (same container
        type) with each element processed. Any other object is returned as-is.
        """
        if isinstance(tensor_obj, torch.Tensor):
            dtype = tensor_obj.dtype
            self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(dtype)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            noise = self._get_noise(tensor_obj)
            # Mask of elements large enough to be perturbed: |x| > sqrt(noise).
            threshold = self.perturbed_value ** 0.5
            should_perturb = TorchC.gt(TorchC.abs(tensor_obj), threshold)
            noisy = TorchC.add(noise, tensor_obj)
            perturbed = TorchC.where(should_perturb, noisy, tensor_obj).to(dtype)
            self.is_added = True
            return perturbed
        if isinstance(tensor_obj, dict):
            return {name: self.add_noise(item) for name, item in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            container_type = type(tensor_obj)
            return container_type([self.add_noise(item) for item in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Add noise to the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is perturbed, stored in
        ``params.perturbed_value``, and ``self.perturbed_result(params)`` is
        returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
        )
        target = params.args[params.valid_input_index]
        params.perturbed_value = self.add_noise(target)
        return self.perturbed_result(params)

    def _get_noise(self, tensor_obj):
        """Return a tensor shaped like ``tensor_obj`` filled with ``self.perturbed_value``."""
        return TorchC.full(
            tensor_obj.shape,
            self.perturbed_value,
            device=str(tensor_obj.device),
            dtype=tensor_obj.dtype,
        )

    def _check_details(self, tensor_obj):
        """
        Decide whether noise may be added to ``tensor_obj``.

        Returns False, after logging a warning, when:
          * the dtype has no configured noise value,
          * the tensor is empty, or
          * its maximum absolute value is below the dtype's tolerance
            (``ABS_TOL_VALUE_DICT``, default ``NOISE_INPUT_LOWER_BOUND``).
        Otherwise returns True.
        """
        if not self.perturbed_value:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False

        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
                f" Cancel adding noise."
            )
            return False

        lower_bound = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            peak = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Fall back to a float32 copy when max/abs fails on the native dtype.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            peak = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()

        # Written as "<" on purpose: a NaN peak fails this test and is accepted.
        if peak < lower_bound:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True
|
|
1
|
+
import torch
|
|
2
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
7
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
+
NpuBaseLayer,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AddNoiseLayer(NpuBaseLayer):
    """
    Perturbation layer that adds a constant, dtype-dependent noise to inputs.

    The noise value is looked up in ``ThresholdConfig.PERTURBATION_VALUE_DICT``
    by tensor dtype. Only elements whose absolute value is greater than
    ``perturbed_value ** 0.5`` receive the noise; all others keep their value.
    """

    def add_noise(self, tensor_obj):
        """
        Recursively add noise to every tensor inside ``tensor_obj``.

        A tensor is perturbed when ``pre_check`` accepts it, otherwise it is
        returned unchanged. Dicts, tuples and lists are rebuilt (same container
        type) with each element processed. Any other object is returned as-is.
        """
        if isinstance(tensor_obj, torch.Tensor):
            dtype = tensor_obj.dtype
            self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(dtype)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            noise = self._get_noise(tensor_obj)
            # Mask of elements large enough to be perturbed: |x| > sqrt(noise).
            threshold = self.perturbed_value ** 0.5
            should_perturb = TorchC.gt(TorchC.abs(tensor_obj), threshold)
            noisy = TorchC.add(noise, tensor_obj)
            perturbed = TorchC.where(should_perturb, noisy, tensor_obj).to(dtype)
            self.is_added = True
            return perturbed
        if isinstance(tensor_obj, dict):
            return {name: self.add_noise(item) for name, item in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            container_type = type(tensor_obj)
            return container_type([self.add_noise(item) for item in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Add noise to the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is perturbed, stored in
        ``params.perturbed_value``, and ``self.perturbed_result(params)`` is
        returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
        )
        target = params.args[params.valid_input_index]
        params.perturbed_value = self.add_noise(target)
        return self.perturbed_result(params)

    def _get_noise(self, tensor_obj):
        """Return a tensor shaped like ``tensor_obj`` filled with ``self.perturbed_value``."""
        return TorchC.full(
            tensor_obj.shape,
            self.perturbed_value,
            device=str(tensor_obj.device),
            dtype=tensor_obj.dtype,
        )

    def _check_details(self, tensor_obj):
        """
        Decide whether noise may be added to ``tensor_obj``.

        Returns False, after logging a warning, when:
          * the dtype has no configured noise value,
          * the tensor is empty, or
          * its maximum absolute value is below the dtype's tolerance
            (``ABS_TOL_VALUE_DICT``, default ``NOISE_INPUT_LOWER_BOUND``).
        Otherwise returns True.
        """
        if not self.perturbed_value:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False

        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
                f" Cancel adding noise."
            )
            return False

        lower_bound = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            peak = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Fall back to a float32 copy when max/abs fails on the native dtype.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            peak = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()

        # Written as "<" on purpose: a NaN peak fails this test and is accepted.
        if peak < lower_bound:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True
|
|
@@ -1,104 +1,104 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
7
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
-
NpuBaseLayer,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class BitNoiseLayer(NpuBaseLayer):
    """
    Perturbation layer that XORs a bit mask into the raw bit pattern of inputs.

    Each supported tensor is reinterpreted (``Tensor.view``) as the integer
    dtype configured in ``ThresholdConfig.PERTURBATION_BIT_DICT`` for its
    dtype. ``bit_tail`` (1, i.e. the least-significant bit) is then XOR'ed into
    every element whose magnitude exceeds the dtype's smallest normal value.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Bitwise operation that combines the integer view with the mask.
        self.bit_mode = TorchC.bitwise_xor
        # Mask applied by ``bit_mode``; 1 targets the lowest bit.
        self.bit_tail: int = 1
        # Integer dtype used to view the tensor currently being processed;
        # None means that tensor's dtype is unsupported.
        self.bit_type = None

    def add_bit_noise(self, tensor_obj):
        """
        Recursively flip bits in every tensor contained in ``tensor_obj``.

        Tensors accepted by ``pre_check`` are perturbed; dicts, tuples and
        lists are rebuilt (same container type) with each element processed;
        any other object is returned unchanged.
        """
        # finfo should be put on the blacklist.

        if isinstance(tensor_obj, torch.Tensor):
            # Refresh bit_type/bit_tail for THIS tensor's dtype before checks.
            self._set_perturbation_bit(tensor_obj)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
            noise = TorchC.full(
                tensor_obj.shape,
                self.bit_tail,
                device=tensor_obj.device,
                dtype=self.bit_type,
            )
            result = tensor_obj.view(self.bit_type)
            # Only elements with |x| > smallest_normal are flipped; zeros and
            # subnormals keep their original bits.
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
                self.bit_mode(result, noise),
                result,
            ).view(tensor_obj.dtype)

            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Flip bits in the selected input argument and return the perturbed result.

        The argument at ``params.valid_input_index`` is perturbed, stored in
        ``params.perturbed_value``, and ``self.perturbed_result(params)`` is
        returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """
        Decide whether a bit flip may be applied to ``tensor_obj``.

        Returns False, after logging, when the dtype has no configured integer
        view type, the tensor is empty, or its maximum absolute value is below
        the dtype's tolerance. Otherwise returns True.
        """
        if not self.bit_type:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Fall back to a float32 copy when max/abs fails on the native dtype.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True

    def _set_perturbation_bit(self, tensor_obj):
        """
        Select the integer view dtype and bit mask for ``tensor_obj``'s dtype.

        ``bit_type`` is always overwritten. An unsupported dtype therefore
        resets it to None instead of silently reusing the value chosen for a
        previously processed tensor, e.g. an earlier element of the same
        list or dict. A stale value would let ``_check_details`` pass and
        lead to ``torch.finfo`` / ``view`` being called with a mismatched
        dtype.
        """
        self.bit_tail = 1
        self.bit_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
|
|
1
|
+
import torch
|
|
2
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
7
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
+
NpuBaseLayer,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BitNoiseLayer(NpuBaseLayer):
|
|
13
|
+
def __init__(self, api_name):
|
|
14
|
+
super().__init__(api_name)
|
|
15
|
+
self.bit_mode = TorchC.bitwise_xor
|
|
16
|
+
self.bit_tail: int = 1
|
|
17
|
+
self.bit_type = None
|
|
18
|
+
|
|
19
|
+
def add_bit_noise(self, tensor_obj):
|
|
20
|
+
"""
|
|
21
|
+
对输入添加噪声
|
|
22
|
+
"""
|
|
23
|
+
# finfo应该列入黑名单
|
|
24
|
+
|
|
25
|
+
if isinstance(tensor_obj, torch.Tensor):
|
|
26
|
+
self._set_perturbation_bit(tensor_obj)
|
|
27
|
+
if not self.pre_check(tensor_obj):
|
|
28
|
+
return tensor_obj
|
|
29
|
+
sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
|
|
30
|
+
noise = TorchC.full(
|
|
31
|
+
tensor_obj.shape,
|
|
32
|
+
self.bit_tail,
|
|
33
|
+
device=tensor_obj.device,
|
|
34
|
+
dtype=self.bit_type,
|
|
35
|
+
)
|
|
36
|
+
result = tensor_obj.view(self.bit_type)
|
|
37
|
+
result = TorchC.where(
|
|
38
|
+
TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
|
|
39
|
+
self.bit_mode(result, noise),
|
|
40
|
+
result,
|
|
41
|
+
).view(tensor_obj.dtype)
|
|
42
|
+
|
|
43
|
+
self.is_added = True
|
|
44
|
+
return result
|
|
45
|
+
if isinstance(tensor_obj, dict):
|
|
46
|
+
return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
|
|
47
|
+
if isinstance(tensor_obj, (tuple, list)):
|
|
48
|
+
return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
|
|
49
|
+
return tensor_obj
|
|
50
|
+
|
|
51
|
+
def handle(self, params: DataParams):
|
|
52
|
+
"""
|
|
53
|
+
对输入添加扰动并返回
|
|
54
|
+
"""
|
|
55
|
+
logger.info_on_rank_0(
|
|
56
|
+
f"[msprobe] Free benchmark: Perturbation is "
|
|
57
|
+
f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
|
|
58
|
+
)
|
|
59
|
+
params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
|
|
60
|
+
return self.perturbed_result(params)
|
|
61
|
+
|
|
62
|
+
def _check_details(self, tensor_obj):
|
|
63
|
+
"""
|
|
64
|
+
判断是否需要添加扰动, bit翻转
|
|
65
|
+
"""
|
|
66
|
+
if not self.bit_type:
|
|
67
|
+
logger.info_on_rank_0(
|
|
68
|
+
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
69
|
+
f"dtype unsupported. Cancel perturbation."
|
|
70
|
+
)
|
|
71
|
+
return False
|
|
72
|
+
if tensor_obj.numel() == 0:
|
|
73
|
+
logger.warning_on_rank_0(
|
|
74
|
+
f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
|
|
75
|
+
f" Cancel adding noise."
|
|
76
|
+
)
|
|
77
|
+
return False
|
|
78
|
+
abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
|
|
79
|
+
tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
|
|
80
|
+
)
|
|
81
|
+
try:
|
|
82
|
+
max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
|
|
83
|
+
except Exception:
|
|
84
|
+
logger.warning_on_rank_0(
|
|
85
|
+
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
86
|
+
f"when calculate maximun value, tensor is changed to float32."
|
|
87
|
+
)
|
|
88
|
+
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
89
|
+
if max_val < abs_tol:
|
|
90
|
+
logger.info_on_rank_0(
|
|
91
|
+
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
92
|
+
f"Maximun value is less than the minimun threshold. Cancel add noise."
|
|
93
|
+
)
|
|
94
|
+
return False
|
|
95
|
+
return True
|
|
96
|
+
|
|
97
|
+
def _set_perturbation_bit(self, tensor_obj):
    """
    Choose the integer dtype used to flip bits of ``tensor_obj``.

    ``tensor_obj.dtype`` is looked up in
    ``ThresholdConfig.PERTURBATION_BIT_DICT``. ``self.bit_type`` is always
    overwritten with the lookup result: the mapped integer dtype when the
    dtype is supported, ``None`` otherwise. When a mapping exists,
    ``self.bit_tail`` is set to 1.

    Clearing ``bit_type`` matters because ``_check_details`` treats a falsy
    ``bit_type`` as "dtype unsupported", and ``add_bit_noise`` may inspect
    several tensors (it recurses into dicts, tuples and lists). Without the
    reset, an unsupported dtype would inherit the bit type chosen for a
    previously inspected tensor.
    """
    bit_len_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
    # Unconditional write: never keep a stale bit type from an earlier tensor.
    self.bit_type = bit_len_type
    if bit_len_type:
        self.bit_tail = 1
|
|
@@ -1,63 +1,63 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
6
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
7
|
-
NpuBaseLayer,
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ChangeValueLayer(NpuBaseLayer):
    """
    Perturbation layer that swaps the first and last values of an input tensor.

    For a 1-D tensor the elements at ``head`` and ``tail`` are exchanged. For
    higher-rank tensors the entries addressed by ``[head][head]`` and
    ``[tail][tail]`` are exchanged: scalars for rank 2, sub-tensors for
    higher ranks.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Index of the leading entry to swap.
        self.head: int = 0
        # Index of the trailing entry to swap (negative: counted from the end).
        self.tail: int = -1

    def change_value(self, tensor_obj):
        """
        Swap the head and tail values of ``tensor_obj``.

        A tensor accepted by ``self.pre_check`` is cloned with
        ``TorchC.clone``, and only the clone is modified and returned. On
        success ``self.is_added`` is set. Dicts, tuples and lists are rebuilt
        with their elements processed recursively. Any other object is
        returned unchanged.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                # Indexing yields views into new_tensor; clone them so the
                # first assignment does not overwrite the saved value.
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Swap head/tail values of the selected input and return the result.

        The positional argument at ``params.valid_input_index`` is passed
        through ``change_value``; the outcome is stored on
        ``params.perturbed_value`` before ``self.perturbed_result(params)``
        is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """
        Return True if the head/tail swap can be applied to ``tensor_obj``.

        The swap needs at least two entries along dim 0. A 0-dim tensor has
        no dim 0 (``size(0)`` would raise), so it is rejected too. An empty
        tensor is also rejected: there is nothing to swap, and for rank >= 2
        the ``[head][head]`` indexing could raise.
        """
        if tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 0 must greater than 1. Cancel change value."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"tensor is empty. Cancel change value."
            )
            return False
        return True
|
|
1
|
+
import torch
|
|
2
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
3
|
+
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
6
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
7
|
+
NpuBaseLayer,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ChangeValueLayer(NpuBaseLayer):
    """
    Perturbation layer that swaps the first and last values of an input tensor.

    For a 1-D tensor the elements at ``head`` and ``tail`` are exchanged. For
    higher-rank tensors the entries addressed by ``[head][head]`` and
    ``[tail][tail]`` are exchanged: scalars for rank 2, sub-tensors for
    higher ranks.
    """

    def __init__(self, api_name):
        super().__init__(api_name)
        # Index of the leading entry to swap.
        self.head: int = 0
        # Index of the trailing entry to swap (negative: counted from the end).
        self.tail: int = -1

    def change_value(self, tensor_obj):
        """
        Swap the head and tail values of ``tensor_obj``.

        A tensor accepted by ``self.pre_check`` is cloned with
        ``TorchC.clone``, and only the clone is modified and returned. On
        success ``self.is_added`` is set. Dicts, tuples and lists are rebuilt
        with their elements processed recursively. Any other object is
        returned unchanged.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                # Indexing yields views into new_tensor; clone them so the
                # first assignment does not overwrite the saved value.
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Swap head/tail values of the selected input and return the result.

        The positional argument at ``params.valid_input_index`` is passed
        through ``change_value``; the outcome is stored on
        ``params.perturbed_value`` before ``self.perturbed_result(params)``
        is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """
        Return True if the head/tail swap can be applied to ``tensor_obj``.

        The swap needs at least two entries along dim 0. A 0-dim tensor has
        no dim 0 (``size(0)`` would raise), so it is rejected too. An empty
        tensor is also rejected: there is nothing to swap, and for rank >= 2
        the ``[head][head]`` indexing could raise.
        """
        if tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 0 must greater than 1. Cancel change value."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"tensor is empty. Cancel change value."
            )
            return False
        return True
|
|
@@ -1,68 +1,68 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from msprobe.core.common.const import Const
|
|
3
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
7
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
-
NpuBaseLayer,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class ImprovePrecisionLayer(NpuBaseLayer):
    """
    Perturbation layer that re-executes an API with higher-precision inputs.

    Floating-point tensors that are neither float32 nor float64 are
    converted to ``self.perturbed_value`` (float32 for float16/bfloat16
    inputs). ``params.origin_func`` is then called again with the converted
    arguments.
    """

    def improve_tensor_precision(self, tensor_obj):
        """
        Recursively upcast low-precision floating-point tensors in ``tensor_obj``.

        Sets ``self.is_added`` as soon as any tensor is converted. Dicts,
        tuples and lists are rebuilt with their elements processed
        recursively. Other objects are returned unchanged.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_values(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Re-run ``params.origin_func`` on upcast inputs and return the result.

        The new result is stored on ``params.perturbed_result`` and returned.
        If no input needed upcasting, the function is not re-run and the
        existing ``params.perturbed_result`` is returned unchanged.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            # Keyword arguments are not forwarded in the backward stage.
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # If every input is already high precision, skip the second execution
        # to avoid holding extra device-memory references.
        if not self.is_added:
            return params.perturbed_result
        # Force the re-run to be out-of-place.
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_values(self, inputs):
        """
        Pick the target dtype for ``inputs``.

        float16 and bfloat16 map to float32. Any other dtype leaves
        ``self.perturbed_value`` untouched.
        """
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """
        Return ``inputs`` converted to ``self.perturbed_value``, kept on its device.
        """
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # ``device`` is a torch.device. The former ``device is
            # CommonField.META`` was an identity test against a constant and
            # could not match, so compare device types instead.
            if device.type == torch.device(CommonField.META).type:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
|
|
1
|
+
import torch
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
7
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
8
|
+
NpuBaseLayer,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ImprovePrecisionLayer(NpuBaseLayer):
    """
    Perturbation layer that re-executes an API with higher-precision inputs.

    Floating-point tensors that are neither float32 nor float64 are
    converted to ``self.perturbed_value`` (float32 for float16/bfloat16
    inputs). ``params.origin_func`` is then called again with the converted
    arguments.
    """

    def improve_tensor_precision(self, tensor_obj):
        """
        Recursively upcast low-precision floating-point tensors in ``tensor_obj``.

        Sets ``self.is_added`` as soon as any tensor is converted. Dicts,
        tuples and lists are rebuilt with their elements processed
        recursively. Other objects are returned unchanged.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_values(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    def handle(self, params: DataParams):
        """
        Re-run ``params.origin_func`` on upcast inputs and return the result.

        The new result is stored on ``params.perturbed_result`` and returned.
        If no input needed upcasting, the function is not re-run and the
        existing ``params.perturbed_result`` is returned unchanged.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            # Keyword arguments are not forwarded in the backward stage.
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # If every input is already high precision, skip the second execution
        # to avoid holding extra device-memory references.
        if not self.is_added:
            return params.perturbed_result
        # Force the re-run to be out-of-place.
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_values(self, inputs):
        """
        Pick the target dtype for ``inputs``.

        float16 and bfloat16 map to float32. Any other dtype leaves
        ``self.perturbed_value`` untouched.
        """
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """
        Return ``inputs`` converted to ``self.perturbed_value``, kept on its device.
        """
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # ``device`` is a torch.device. The former ``device is
            # CommonField.META`` was an identity test against a constant and
            # could not match, so compare device types instead.
            if device.type == torch.device(CommonField.META).type:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
|