mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.common.utils import TorchC
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
    NpuBaseLayer,
)


class AddNoiseLayer(NpuBaseLayer):
    """Perturb one API input by adding a dtype-dependent constant to it."""

    def add_noise(self, tensor_obj):
        """Return ``tensor_obj`` with noise added to its first eligible tensor.

        Dicts, tuples and lists are traversed recursively and rebuilt with the
        same container type; any other non-tensor value is returned unchanged.
        Only one tensor is perturbed: once ``self.is_added`` is set,
        ``pre_check`` rejects every later tensor.
        """
        if isinstance(tensor_obj, torch.Tensor):
            # Noise magnitude for this dtype; None (unsupported dtype) makes
            # _check_details cancel the perturbation.
            self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
                tensor_obj.dtype
            )
            if not self.pre_check(tensor_obj):
                return tensor_obj
            noise = self._get_noise(tensor_obj)
            # Only elements with |x| > sqrt(noise) receive the noise; smaller
            # elements are kept unchanged.
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5),
                TorchC.add(noise, tensor_obj),
                tensor_obj,
            ).to(tensor_obj.dtype)
            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_noise(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams) -> Any:
        """Add noise to the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the value
        returned is whatever ``self.perturbed_result(params)`` returns.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _get_noise(self, tensor_obj):
        """Build a tensor shaped like ``tensor_obj`` filled with ``self.perturbed_value``.

        The noise tensor uses the same dtype and device as ``tensor_obj``.
        """
        dtype = tensor_obj.dtype
        device = str(tensor_obj.device)
        noise = TorchC.full(
            tensor_obj.shape,
            self.perturbed_value,
            device=device,
            dtype=dtype,
        )
        return noise

    def _check_details(self, tensor_obj):
        """Decide whether noise should be added to ``tensor_obj``.

        Returns False (and logs the reason) when the dtype has no noise value,
        the tensor is empty, or its largest absolute value is below the
        dtype's absolute tolerance; otherwise returns True.
        """
        if not self.perturbed_value:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Best effort: if abs/max fail for this tensor, retry on a float32
            # copy rather than aborting the perturbation.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.common.utils import TorchC
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
    NpuBaseLayer,
)


class BitNoiseLayer(NpuBaseLayer):
    """Perturb one API input by XOR-ing the lowest bit of its integer bit pattern."""

    def __init__(self, api_name):
        super().__init__(api_name)
        # Bitwise op applied between the integer view of the input and the noise.
        self.bit_mode = TorchC.bitwise_xor
        # Value combined into every element; 1 targets the least significant bit.
        self.bit_tail: int = 1
        # Integer dtype used to reinterpret the current tensor's bits; None when
        # the current tensor's dtype is not in PERTURBATION_BIT_DICT.
        self.bit_type = None

    def add_bit_noise(self, tensor_obj):
        """Return ``tensor_obj`` with bit noise applied to its first eligible tensor.

        Dicts, tuples and lists are traversed recursively and rebuilt with the
        same container type; any other non-tensor value is returned unchanged.
        """
        # torch.finfo should be blacklisted — presumably excluded from API
        # hooking since it is called below; TODO confirm intent.

        if isinstance(tensor_obj, torch.Tensor):
            self._set_perturbation_bit(tensor_obj)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
            noise = TorchC.full(
                tensor_obj.shape,
                self.bit_tail,
                device=tensor_obj.device,
                dtype=self.bit_type,
            )
            # Reinterpret the floating-point storage as integers (no value
            # conversion) so the bits can be manipulated directly; assumes
            # bit_type has the same width as tensor_obj.dtype.
            result = tensor_obj.view(self.bit_type)
            # Leave zeros and subnormal values untouched; only |x| above the
            # smallest normal number is perturbed.
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
                self.bit_mode(result, noise),
                result,
            ).view(tensor_obj.dtype)

            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams) -> Any:
        """Apply bit noise to the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the value
        returned is whatever ``self.perturbed_result(params)`` returns.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether bit noise should be applied to ``tensor_obj``.

        Returns False (and logs the reason) when no integer view dtype is known
        for the tensor's dtype, the tensor is empty, or its largest absolute
        value is below the dtype's absolute tolerance; otherwise True.
        """
        if not self.bit_type:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Best effort: retry on a float32 copy if abs/max fail for this tensor.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True

    def _set_perturbation_bit(self, tensor_obj):
        """Select the integer view dtype and bit value for ``tensor_obj``'s dtype.

        ``bit_type`` is reset on every call. An unsupported dtype therefore
        yields None, which ``_check_details`` rejects. It cannot silently reuse
        the integer type chosen for a previously visited tensor of another
        width.
        """
        self.bit_tail = 1
        self.bit_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.common.utils import TorchC
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
    NpuBaseLayer,
)


class ChangeValueLayer(NpuBaseLayer):
    """Perturb one API input by swapping its first and last entries."""

    def __init__(self, api_name):
        super().__init__(api_name)
        # Indices of the two entries that are exchanged.
        self.head: int = 0
        self.tail: int = -1

    def change_value(self, tensor_obj):
        """Return a copy of the first eligible tensor with head and tail swapped.

        For 1-D tensors the first and last elements are exchanged; for higher
        ranks ``t[0][0]`` and ``t[-1][-1]`` are exchanged. The input tensor is
        cloned, never modified in place. Dicts, tuples and lists are traversed
        recursively; other values (and ineligible tensors) are returned as-is.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    def handle(self, params: DataParams) -> Any:
        """Swap values in the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the value
        returned is whatever ``self.perturbed_result(params)`` returns.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """Decide whether head/tail values of ``tensor_obj`` can be swapped.

        Rejects 0-dim tensors, where ``size(0)`` would raise. Rejects tensors
        whose first dimension has fewer than 2 entries. Rejects empty tensors,
        where ``t[0][0]`` would raise on a zero-sized inner dimension.
        """
        if tensor_obj.ndim == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"tensor is 0-dimensional. Cancel change value."
            )
            return False
        if tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 0 must greater than 1. Cancel change value."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"tensor is empty. Cancel change value."
            )
            return False
        return True
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.core.common.const import Const
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.constant import CommonField
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
    NpuBaseLayer,
)


class ImprovePrecisionLayer(NpuBaseLayer):
    """Re-run an API with its low-precision floating-point inputs upcast."""

    def improve_tensor_precision(self, tensor_obj):
        """Upcast every floating-point tensor that is not float32/float64.

        Unlike the other layers this does not go through ``pre_check``, so all
        qualifying tensors are converted, not just the first. Dicts, tuples and
        lists are traversed recursively; other values are returned unchanged.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_valus(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    def handle(self, params: DataParams) -> Any:
        """Re-run ``params.origin_func`` with upcast inputs and return its result.

        In the backward stage the call is made without any kwargs. If nothing
        was upcast, the API is not re-run and the existing
        ``params.perturbed_result`` is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # If every input is already high precision, skip the second execution
        # to avoid holding extra device memory.
        if not self.is_added:
            return params.perturbed_result
        # Force out-of-place execution so the original inputs are not overwritten.
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_valus(self, inputs):
        """Use float32 as the target dtype for float16/bfloat16 inputs.

        Other dtypes leave ``self.perturbed_value`` unchanged.
        """
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """Cast ``inputs`` to ``self.perturbed_value``, keeping it on its device."""
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # Compare by value: an identity (`is`) test of a torch.device
            # against a constant can never match. Assumes CommonField.META is
            # the device string "meta" — TODO confirm.
            if str(device) == CommonField.META:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
    NpuBaseLayer,
)


class NoChangeLayer(NpuBaseLayer):
    """Control layer: re-run the API with its inputs left untouched."""

    def no_change(self, tensor_obj):
        """Return ``tensor_obj`` unchanged and mark the input as handled.

        The API is simply executed a second time with identical inputs.
        """
        self.is_added = True
        return tensor_obj

    def handle(self, params: DataParams) -> Any:
        """Re-run the API with the input at ``params.valid_input_index`` unchanged.

        The (unchanged) input is stored in ``params.perturbed_value``; the
        value returned is whatever ``self.perturbed_result(params)`` returns.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
        )
        params.perturbed_value = self.no_change(params.args[params.valid_input_index])
        return self.perturbed_result(params)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
6
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NpuBaseLayer(BaseLayer):
    """Shared machinery for NPU perturbation layers.

    A subclass perturbs ``params.args[params.valid_input_index]``, stores the
    result in ``params.perturbed_value`` and then calls ``perturbed_result``
    to execute the original API with that value substituted.
    """

    def __init__(self, api_name: str) -> None:
        super().__init__(api_name)
        # Perturbation parameter chosen by the subclass (noise value, dtype, ...).
        self.perturbed_value = None
        # Becomes True once an input has been adjusted.
        self.is_added = False

    @staticmethod
    def perturbed_result(params: DataParams) -> Any:
        """Call ``params.origin_func`` with the perturbed input swapped in.

        The result is stored in ``params.perturbed_result`` and returned. If
        ``params.kwargs`` has an ``inplace`` key it is set to False, mutating
        ``params.kwargs``, so the re-run does not operate in place.
        """
        idx = params.valid_input_index
        leading = params.args[:idx]
        trailing = params.args[idx + 1:]
        if "inplace" in params.kwargs:
            params.kwargs["inplace"] = False
        outcome = params.origin_func(
            *leading, params.perturbed_value, *trailing, **params.kwargs
        )
        params.perturbed_result = outcome
        return params.perturbed_result

    @abstractmethod
    def handle(self, params: DataParams) -> Any:
        """Apply this layer's perturbation and return the re-executed result."""

    def pre_check(self, tensor_obj):
        """Report whether ``tensor_obj`` may be perturbed.

        Only the first qualifying tensor is perturbed, so anything after
        ``is_added`` is set is rejected. Non-floating-point tensors are
        rejected as well. The final verdict is delegated to ``_check_details``.
        """
        if self.is_added or not torch.is_floating_point(tensor_obj):
            return False
        return bool(self._check_details(tensor_obj))

    def _check_details(self, tensor_obj):
        """Layer-specific eligibility hook; accepts every tensor by default."""
        return True
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import Any

import torch
from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.common.utils import Tools
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer


class CpuLayer(BaseLayer):
    """Benchmark layer that re-runs an API on the CPU."""

    def handle(self, params: DataParams) -> Any:
        """Re-run ``params.origin_func`` with args and kwargs moved to CPU.

        Conversion goes through ``Tools.convert_device_and_dtype`` with
        ``change_dtype=True``. Presumably this also changes tensor dtypes; see
        ``Tools`` to confirm. The result is stored in
        ``params.perturbed_result`` and returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
        )
        new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
        new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result
|
|
File without changes
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from msprobe.core.common.const import Const
|
|
7
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
8
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
9
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
10
|
+
FuzzThreshold,
|
|
11
|
+
NormType,
|
|
12
|
+
PerturbationMode,
|
|
13
|
+
)
|
|
14
|
+
from msprobe.pytorch.free_benchmark.common.params import (
|
|
15
|
+
DataParams,
|
|
16
|
+
HandlerParams,
|
|
17
|
+
make_unequal_row,
|
|
18
|
+
)
|
|
19
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FuzzHandler(ABC):
|
|
23
|
+
def __init__(self, params: HandlerParams) -> None:
    """Keep the handler parameters and start with an empty list of unequal rows."""
    # Rows describing mismatches found while comparing outputs.
    self.unequal_rows = []
    self.params = params
|
|
26
|
+
|
|
27
|
+
@staticmethod
def pre_process(origin_ouput, perturbed_output):
    """Align the original and perturbed outputs for comparison.

    Tuple outputs exposing both ``values`` and ``indices`` are reduced to
    their ``values``. The original output is cast to the perturbed output's
    dtype and device.

    Returns:
        (aligned original output, perturbed output, absolute tolerance).
        The tolerance comes from ``ThresholdConfig.ABS_TOL_VALUE_DICT`` for
        the perturbed dtype. It falls back to ``FuzzThreshold.F32_THD.value``
        when there is no entry, so ``None`` is never returned as a tolerance.
    """
    if (
        isinstance(origin_ouput, tuple)
        and hasattr(origin_ouput, "values")
        and hasattr(origin_ouput, "indices")
    ):
        # Compare values only; indices may legitimately differ.
        origin_ouput = origin_ouput.values
        perturbed_output = perturbed_output.values
    if hasattr(perturbed_output, "dtype"):
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            perturbed_output.dtype, FuzzThreshold.F32_THD.value
        )
    else:
        abs_tol = FuzzThreshold.F32_THD.value
    return (
        origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
        perturbed_output,
        abs_tol,
    )
|
|
45
|
+
|
|
46
|
+
@staticmethod
|
|
47
|
+
def convert_overflow_ratio_to_consistent(ratio):
|
|
48
|
+
if math.isnan(ratio) or math.isinf(ratio):
|
|
49
|
+
return ThresholdConfig.COMP_CONSISTENT
|
|
50
|
+
return ratio
|
|
51
|
+
|
|
52
|
+
@abstractmethod
|
|
53
|
+
def get_threshold(self, dtype):
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
@abstractmethod
|
|
57
|
+
def handle(self, data_params: DataParams) -> Any:
|
|
58
|
+
pass
|
|
59
|
+
|
|
60
|
+
def get_ratio_from_specific_norm(
|
|
61
|
+
self, origin_output, perturbed_output, norm_type, abs_tol
|
|
62
|
+
):
|
|
63
|
+
if norm_type == NormType.ENDLESS_NORM:
|
|
64
|
+
return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
|
|
65
|
+
return ThresholdConfig.COMP_CONSISTENT
|
|
66
|
+
|
|
67
|
+
def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
|
|
68
|
+
ratio_tensor1 = TorchC.where(
|
|
69
|
+
TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
|
|
70
|
+
TorchC.div(
|
|
71
|
+
TorchC.abs(origin_output),
|
|
72
|
+
TorchC.add(TorchC.abs(perturbed_output), abs_tol),
|
|
73
|
+
),
|
|
74
|
+
1,
|
|
75
|
+
)
|
|
76
|
+
ratio_tensor2 = TorchC.where(
|
|
77
|
+
TorchC.gt(TorchC.abs(origin_output), abs_tol),
|
|
78
|
+
TorchC.div(
|
|
79
|
+
TorchC.abs(perturbed_output),
|
|
80
|
+
TorchC.add(TorchC.abs(origin_output), abs_tol),
|
|
81
|
+
),
|
|
82
|
+
1,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
norm1 = self.convert_overflow_ratio_to_consistent(
|
|
86
|
+
TorchC.max(ratio_tensor1).item()
|
|
87
|
+
)
|
|
88
|
+
norm2 = self.convert_overflow_ratio_to_consistent(
|
|
89
|
+
TorchC.max(ratio_tensor2).item()
|
|
90
|
+
)
|
|
91
|
+
norm3 = self.convert_overflow_ratio_to_consistent(
|
|
92
|
+
TorchC.min(ratio_tensor1).item()
|
|
93
|
+
)
|
|
94
|
+
if norm3 < 0:
|
|
95
|
+
ratio = ThresholdConfig.SYMBOL_FLIPPING
|
|
96
|
+
else:
|
|
97
|
+
ratio = max(norm1, norm2)
|
|
98
|
+
return ratio
|
|
99
|
+
|
|
100
|
+
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
|
|
101
|
+
try:
|
|
102
|
+
origin_output, perturbed_output, abs_tol = self.pre_process(
|
|
103
|
+
origin_output, perturbed_output
|
|
104
|
+
)
|
|
105
|
+
except Exception as e:
|
|
106
|
+
logger.warning_on_rank_0(
|
|
107
|
+
f"[msprobe] Free Benchmark: For {self.params.api_name}, "
|
|
108
|
+
f"when computing ratio,"
|
|
109
|
+
f" y1 or y2 dtype is not supported {e}"
|
|
110
|
+
)
|
|
111
|
+
return ThresholdConfig.COMP_NAN
|
|
112
|
+
if self.params.fuzz_stage == Const.BACKWARD:
|
|
113
|
+
abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
|
|
114
|
+
else:
|
|
115
|
+
abs_tol = abs_tol ** 0.5
|
|
116
|
+
return self.get_ratio_from_specific_norm(
|
|
117
|
+
origin_output, perturbed_output, norm_type, abs_tol
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
def npu_compare(
|
|
121
|
+
self, origin_output, perturbed_output
|
|
122
|
+
) -> Tuple[bool, Optional[float]]:
|
|
123
|
+
|
|
124
|
+
if isinstance(perturbed_output, int):
|
|
125
|
+
return origin_output == perturbed_output, None
|
|
126
|
+
elif isinstance(perturbed_output, float):
|
|
127
|
+
if perturbed_output == 0:
|
|
128
|
+
origin_output += FuzzThreshold.F32_THD
|
|
129
|
+
perturbed_output += FuzzThreshold.F32_THD
|
|
130
|
+
return (
|
|
131
|
+
math.isclose(origin_output, perturbed_output),
|
|
132
|
+
origin_output / perturbed_output,
|
|
133
|
+
)
|
|
134
|
+
elif not isinstance(perturbed_output, torch.Tensor):
|
|
135
|
+
logger.warning_on_rank_0(
|
|
136
|
+
f"[msprobe] Free Benchmark: For {self.params.api_name} "
|
|
137
|
+
f"The compare for output type {type(perturbed_output)} is not supported"
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
|
|
141
|
+
ratio = self.ratio_calculate(
|
|
142
|
+
origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
|
|
143
|
+
)
|
|
144
|
+
if ratio == ThresholdConfig.SYMBOL_FLIPPING:
|
|
145
|
+
is_consistent = False
|
|
146
|
+
else:
|
|
147
|
+
is_consistent = threshold >= ratio >= 1 / threshold
|
|
148
|
+
return is_consistent, ratio
|
|
149
|
+
|
|
150
|
+
def cmp_output_npu(self, data_params: DataParams):
|
|
151
|
+
npu_consistent = True
|
|
152
|
+
max_fuzz_ratio = 0
|
|
153
|
+
try:
|
|
154
|
+
if isinstance(data_params.original_result, torch.Tensor):
|
|
155
|
+
is_consistent, ratio = self.npu_compare(
|
|
156
|
+
data_params.original_result, data_params.perturbed_result
|
|
157
|
+
)
|
|
158
|
+
npu_consistent = is_consistent
|
|
159
|
+
max_fuzz_ratio = (
|
|
160
|
+
max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
|
|
161
|
+
)
|
|
162
|
+
data_params.is_consistent = is_consistent and data_params.is_consistent
|
|
163
|
+
if not is_consistent and data_params.grad_unequal_flag:
|
|
164
|
+
self.unequal_rows.append(
|
|
165
|
+
make_unequal_row(data_params, self.params, ratio=ratio)
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
elif isinstance(data_params.original_result, (list, tuple)):
|
|
169
|
+
for index_, origin_item in enumerate(data_params.original_result):
|
|
170
|
+
is_consistent, ratio = self.npu_compare(
|
|
171
|
+
origin_item, data_params.perturbed_result[index_]
|
|
172
|
+
)
|
|
173
|
+
npu_consistent = npu_consistent and is_consistent
|
|
174
|
+
max_fuzz_ratio = (
|
|
175
|
+
max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
|
|
176
|
+
)
|
|
177
|
+
data_params.is_consistent = (
|
|
178
|
+
is_consistent and data_params.is_consistent
|
|
179
|
+
)
|
|
180
|
+
if not is_consistent and data_params.grad_unequal_flag:
|
|
181
|
+
self.unequal_rows.append(
|
|
182
|
+
make_unequal_row(
|
|
183
|
+
data_params, self.params, ratio=ratio, index=index_
|
|
184
|
+
)
|
|
185
|
+
)
|
|
186
|
+
except Exception as e:
|
|
187
|
+
logger.warning_on_rank_0(
|
|
188
|
+
f"[msprobe] Free Benchmark: For {self.params.api_name}, "
|
|
189
|
+
f"when campare the result exception raise {e}"
|
|
190
|
+
)
|
|
191
|
+
return npu_consistent, max_fuzz_ratio
|
|
192
|
+
|
|
193
|
+
def get_unequal_rows(self):
|
|
194
|
+
return self.unequal_rows
|
|
195
|
+
|
|
196
|
+
def _get_default_threshold(self, dtype):
|
|
197
|
+
if self.params.pert_mode == PerturbationMode.NO_CHANGE:
|
|
198
|
+
threshold = ThresholdConfig.COMP_CONSISTENT
|
|
199
|
+
else:
|
|
200
|
+
threshold = ThresholdConfig.DTYPE_PER_THD.get(
|
|
201
|
+
dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32)
|
|
202
|
+
)
|
|
203
|
+
return threshold
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
7
|
+
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
|
|
8
|
+
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CheckerHandler(FuzzHandler):
    """Result handler that only checks perturbed results; it never alters them.

    ``handle`` always returns the original API result. Inconsistent comparisons
    may be recorded in ``self.unequal_rows``.
    """

    def other_compare(self, data_params: DataParams) -> bool:
        """Compare results with ``SingleCompare`` (the non-NPU path).

        Appends an unequal row when the comparison reports a mismatch.

        Returns:
            The value of ``SingleCompare().compare_seq`` for the original and
            perturbed results.
        """
        is_consistent = SingleCompare().compare_seq(
            data_params.original_result, data_params.perturbed_result
        )
        if not is_consistent:
            self.unequal_rows.append(
                make_unequal_row(data_params, self.params)
            )
        # Honour the ``-> bool`` annotation; the method used to return None.
        return is_consistent

    def get_threshold(self, dtype):
        """Return the base handler's default threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Compare the results and return the original result unchanged.

        A perturbed result that is a ``bool``, or that fails
        ``Tools.is_float_tensor``, is not compared. On NPU, the base class's
        ``cmp_output_npu`` is used; on other devices, ``other_compare``.
        Exceptions raised during comparison are logged, not propagated.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result
        try:
            if self.params.fuzz_device == DeviceType.NPU:
                self.cmp_output_npu(data_params)
            else:
                self.other_compare(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare the result exception raise {e}"
            )
        return data_params.original_result
|