mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,30 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
17
|
+
from collections import defaultdict
|
|
2
18
|
|
|
3
19
|
from mindspore import Tensor
|
|
4
|
-
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
5
20
|
from mindspore._c_expression import PyNativeExecutor_
|
|
21
|
+
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
6
22
|
|
|
7
23
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
8
|
-
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
|
|
24
|
+
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
9
25
|
from msprobe.core.common.const import Const
|
|
26
|
+
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
|
|
27
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
10
28
|
|
|
11
29
|
|
|
12
30
|
def dump_jit(name, in_feat, out_feat, is_forward):
|
|
@@ -17,19 +35,27 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
17
35
|
result = ori_args[0:index]
|
|
18
36
|
else:
|
|
19
37
|
result = "JitFunction"
|
|
20
|
-
if is_forward:
|
|
21
|
-
name_template = "Jit." + result + ".forward"
|
|
22
|
-
else:
|
|
23
|
-
name_template = "Jit." + result + ".backward"
|
|
24
38
|
if JitDump.need_dump():
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
39
|
+
if is_forward:
|
|
40
|
+
JitDump.jit_count[result] += 1
|
|
41
|
+
name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
|
|
42
|
+
Const.FORWARD
|
|
43
|
+
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
44
|
+
module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
|
|
45
|
+
JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
|
|
46
|
+
else:
|
|
47
|
+
name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
|
|
48
|
+
Const.BACKWARD
|
|
49
|
+
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
50
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat ,grad_output=out_feat)
|
|
51
|
+
JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
|
|
28
52
|
|
|
29
53
|
|
|
30
54
|
class JitDump(_MindsporeFunctionExecutor):
|
|
31
55
|
dump_config = None
|
|
32
56
|
jit_enable = False
|
|
57
|
+
jit_dump_switch = True
|
|
58
|
+
jit_count = defaultdict(int)
|
|
33
59
|
|
|
34
60
|
def __init__(self, *args, **kwargs):
|
|
35
61
|
super().__init__(*args, **kwargs)
|
|
@@ -38,11 +64,9 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
38
64
|
def __call__(self, *args, **kwargs):
|
|
39
65
|
api_register.api_set_ori_func()
|
|
40
66
|
out = super().__call__(*args, **kwargs)
|
|
41
|
-
if
|
|
42
|
-
dump_jit(
|
|
43
|
-
|
|
44
|
-
dump_jit(args[0], args[1:], out, True)
|
|
45
|
-
JitDump.jit_enable = True
|
|
67
|
+
if JitDump.jit_dump_switch and len(args) > 0:
|
|
68
|
+
dump_jit(args[0], args, out, True)
|
|
69
|
+
JitDump.jit_enable = True
|
|
46
70
|
api_register.api_set_hook_func()
|
|
47
71
|
return out
|
|
48
72
|
|
|
@@ -62,11 +86,11 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
62
86
|
return False
|
|
63
87
|
return True
|
|
64
88
|
|
|
65
|
-
def grad(self, obj, grad, weights, grad_position, *args,
|
|
66
|
-
if JitDump.jit_enable:
|
|
89
|
+
def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
|
|
90
|
+
if JitDump.jit_dump_switch and JitDump.jit_enable:
|
|
67
91
|
api_register.api_set_ori_func()
|
|
68
92
|
output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
|
69
|
-
if JitDump.jit_enable:
|
|
93
|
+
if JitDump.jit_dump_switch and JitDump.jit_enable:
|
|
70
94
|
dump_jit(obj, args, None, False)
|
|
71
95
|
api_register.api_set_hook_func()
|
|
72
96
|
return output
|
|
@@ -1,8 +1,24 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import json
|
|
3
|
-
|
|
4
|
-
|
|
17
|
+
import os
|
|
18
|
+
|
|
5
19
|
from msprobe.core.common.file_utils import FileOpen, create_directory
|
|
20
|
+
from msprobe.mindspore.common.log import logger
|
|
21
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
6
22
|
|
|
7
23
|
|
|
8
24
|
class KernelGraphDump:
|
|
@@ -1,10 +1,25 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import json
|
|
17
|
+
import os
|
|
3
18
|
|
|
4
|
-
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
5
|
-
from msprobe.mindspore.common.log import logger
|
|
6
|
-
from msprobe.core.common.file_utils import FileOpen, create_directory
|
|
7
19
|
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory
|
|
21
|
+
from msprobe.mindspore.common.log import logger
|
|
22
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
8
23
|
|
|
9
24
|
|
|
10
25
|
class KernelKbykDump:
|
|
@@ -1,17 +1,32 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
import importlib
|
|
17
|
+
import inspect
|
|
18
|
+
import os
|
|
4
19
|
|
|
5
20
|
import mindspore as ms
|
|
6
21
|
from mindspore.communication import comm_func
|
|
7
22
|
|
|
8
|
-
from msprobe.core.common.file_utils import load_yaml, check_path_length
|
|
9
23
|
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.core.common.file_utils import check_path_length, load_yaml
|
|
10
25
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
11
26
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
12
|
-
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
13
27
|
from msprobe.mindspore.common.log import logger
|
|
14
28
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
29
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
15
30
|
from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
|
|
16
31
|
|
|
17
32
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Optional, Any, Tuple, Dict, Callable
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,14 +1,28 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Optional
|
|
4
18
|
|
|
5
19
|
import mindspore as ms
|
|
6
20
|
from mindspore import Tensor
|
|
7
21
|
|
|
8
|
-
from msprobe.mindspore.runtime import Runtime
|
|
9
22
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
10
|
-
from .config import Config
|
|
11
|
-
from .handler_params import HandlerParams
|
|
23
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
24
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
+
from msprobe.mindspore.runtime import Runtime
|
|
12
26
|
|
|
13
27
|
|
|
14
28
|
class Tools:
|
|
@@ -1,6 +1,20 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.common.const import Const, FreeBenchmarkConst
|
|
1
17
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
2
|
-
from msprobe.mindspore.common.const import Const
|
|
3
|
-
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
4
18
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
5
19
|
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
6
20
|
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
@@ -1,16 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import sys
|
|
3
18
|
import traceback
|
|
4
19
|
from functools import wraps
|
|
5
|
-
from typing import
|
|
20
|
+
from typing import Dict, List, Tuple
|
|
6
21
|
|
|
7
22
|
from mindspore import ops
|
|
8
23
|
|
|
9
|
-
from msprobe.mindspore.runtime import Runtime
|
|
10
24
|
from msprobe.mindspore.common.log import logger
|
|
11
25
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
12
26
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
13
|
-
from .dec_forward import ForwardSelfChecker
|
|
27
|
+
from msprobe.mindspore.free_benchmark.decorator.dec_forward import ForwardSelfChecker
|
|
28
|
+
from msprobe.mindspore.runtime import Runtime
|
|
14
29
|
|
|
15
30
|
|
|
16
31
|
def decorate(original_func, decorate_func, api_name=None):
|
|
@@ -1,14 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
from abc import ABC, abstractmethod
|
|
3
|
-
from typing import Any,
|
|
18
|
+
from typing import Any, Optional, Tuple
|
|
4
19
|
|
|
5
20
|
import mindspore as ms
|
|
6
21
|
from mindspore import Tensor, ops
|
|
7
22
|
|
|
8
|
-
from msprobe.mindspore.common.log import logger
|
|
9
|
-
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
10
23
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
11
25
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
26
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
12
27
|
|
|
13
28
|
|
|
14
29
|
class BaseHandler(ABC):
|
|
@@ -1,14 +1,29 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
from dataclasses import asdict
|
|
17
|
+
from typing import Any
|
|
3
18
|
|
|
4
19
|
from mindspore import Tensor, ops
|
|
5
20
|
|
|
21
|
+
from msprobe.core.data_dump.json_writer import DataWriter
|
|
6
22
|
from msprobe.mindspore.common.log import logger
|
|
7
23
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
8
|
-
from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
|
|
9
24
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
10
25
|
from msprobe.mindspore.free_benchmark.common.utils import make_unequal_row
|
|
11
|
-
from msprobe.
|
|
26
|
+
from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
|
|
12
27
|
|
|
13
28
|
|
|
14
29
|
class CheckHandler(BaseHandler):
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
from mindspore import Tensor
|
|
@@ -1,8 +1,23 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
1
17
|
from msprobe.mindspore.common.log import logger
|
|
2
18
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
3
|
-
from msprobe.mindspore.
|
|
4
|
-
from .
|
|
5
|
-
from .fix_handler import FixHandler
|
|
19
|
+
from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler
|
|
20
|
+
from msprobe.mindspore.free_benchmark.handler.fix_handler import FixHandler
|
|
6
21
|
|
|
7
22
|
|
|
8
23
|
class HandlerFactory:
|
|
@@ -1,11 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
from mindspore import Tensor, ops
|
|
4
19
|
|
|
20
|
+
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
5
21
|
from msprobe.mindspore.common.log import logger
|
|
6
|
-
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
7
22
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
8
|
-
from msprobe.mindspore.
|
|
23
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
9
24
|
|
|
10
25
|
|
|
11
26
|
class AddNoisePerturbation(BasePerturbation):
|
|
@@ -43,25 +58,25 @@ class AddNoisePerturbation(BasePerturbation):
|
|
|
43
58
|
|
|
44
59
|
return inputs
|
|
45
60
|
|
|
46
|
-
def _get_noise(self,
|
|
61
|
+
def _get_noise(self, tensor):
|
|
47
62
|
"""
|
|
48
63
|
得到要添加的噪声值
|
|
49
64
|
|
|
50
65
|
"""
|
|
51
66
|
if self.is_fuzzed:
|
|
52
67
|
return False
|
|
53
|
-
if not ops.is_floating_point(
|
|
68
|
+
if not ops.is_floating_point(tensor) or ops.numel(tensor) == 0:
|
|
54
69
|
return False
|
|
55
70
|
|
|
56
|
-
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(
|
|
71
|
+
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
|
|
57
72
|
if not pert_value:
|
|
58
73
|
return False
|
|
59
74
|
else:
|
|
60
75
|
self.perturbation_value = pert_value
|
|
61
76
|
|
|
62
|
-
max_val = ops.max(ops.abs(
|
|
77
|
+
max_val = ops.max(ops.abs(tensor))[0].item()
|
|
63
78
|
if max_val < pert_value:
|
|
64
79
|
return False
|
|
65
80
|
|
|
66
|
-
noise = ops.full(
|
|
81
|
+
noise = ops.full(tensor.shape, self.perturbation_value, dtype=tensor.dtype)
|
|
67
82
|
return noise
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
@@ -1,10 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
4
19
|
from mindspore import Tensor, ops
|
|
5
20
|
|
|
6
|
-
from msprobe.mindspore.common.log import logger
|
|
7
21
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
22
|
+
from msprobe.mindspore.common.log import logger
|
|
8
23
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
9
24
|
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
10
25
|
|
|
@@ -45,19 +60,19 @@ class BitNoisePerturbation(BasePerturbation):
|
|
|
45
60
|
params.args = args
|
|
46
61
|
return self.get_fuzzed_result(params)
|
|
47
62
|
|
|
48
|
-
def _get_bit_len_type(self,
|
|
63
|
+
def _get_bit_len_type(self, tensor):
|
|
49
64
|
if self.is_fuzzed:
|
|
50
65
|
return False
|
|
51
|
-
if not isinstance(
|
|
52
|
-
|
|
66
|
+
if not isinstance(tensor, Tensor) or not ops.is_floating_point(tensor) or \
|
|
67
|
+
tensor.numel() == 0:
|
|
53
68
|
return False
|
|
54
|
-
bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(
|
|
69
|
+
bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(tensor.dtype)
|
|
55
70
|
if not bit_len_type:
|
|
56
71
|
return False
|
|
57
|
-
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(
|
|
72
|
+
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
|
|
58
73
|
if not pert_value:
|
|
59
74
|
return False
|
|
60
|
-
max_val = ops.max(ops.abs(
|
|
75
|
+
max_val = ops.max(ops.abs(tensor))[0].item()
|
|
61
76
|
if max_val < pert_value:
|
|
62
77
|
return False
|
|
63
78
|
return bit_len_type
|
|
@@ -1,14 +1,39 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
|
-
from mindspore import Tensor
|
|
18
|
+
from mindspore import Tensor, ops
|
|
4
19
|
|
|
5
20
|
from msprobe.mindspore.common.log import logger
|
|
6
|
-
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
7
21
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
22
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
8
23
|
|
|
9
24
|
|
|
10
25
|
class ExchangeValuePerturbation(BasePerturbation):
|
|
11
26
|
|
|
27
|
+
@staticmethod
|
|
28
|
+
def _check_tensor_shape(inputs):
|
|
29
|
+
dims = len(inputs.shape)
|
|
30
|
+
if dims == 1 and inputs.shape[0] > 1:
|
|
31
|
+
return True
|
|
32
|
+
if dims > 1 and inputs.shape[1] > 0:
|
|
33
|
+
if inputs.shape[0] > 1 or inputs.shape[1] > 1:
|
|
34
|
+
return True
|
|
35
|
+
return False
|
|
36
|
+
|
|
12
37
|
def handle(self, params: HandlerParams) -> Any:
|
|
13
38
|
"""
|
|
14
39
|
返回首尾交换后的api输出
|
|
@@ -25,22 +50,23 @@ class ExchangeValuePerturbation(BasePerturbation):
|
|
|
25
50
|
返回首尾交换后的api输入
|
|
26
51
|
|
|
27
52
|
"""
|
|
28
|
-
if isinstance(inputs, Tensor):
|
|
29
|
-
if
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
53
|
+
if isinstance(inputs, Tensor) and ops.is_floating_point(inputs):
|
|
54
|
+
if self.is_fuzzed or not self._check_tensor_shape(inputs):
|
|
55
|
+
return inputs
|
|
56
|
+
result = inputs.copy()
|
|
57
|
+
if len(inputs.shape) == 1:
|
|
58
|
+
first_element = inputs[0]
|
|
59
|
+
last_element = inputs[-1]
|
|
60
|
+
result[0] = last_element
|
|
61
|
+
result[-1] = first_element
|
|
62
|
+
else:
|
|
63
|
+
first_element = inputs[0][0]
|
|
64
|
+
last_element = inputs[-1][-1]
|
|
65
|
+
result[0][0] = last_element
|
|
66
|
+
result[-1][-1] = first_element
|
|
67
|
+
|
|
68
|
+
self.is_fuzzed = True
|
|
69
|
+
return result
|
|
44
70
|
|
|
45
71
|
if isinstance(inputs, dict):
|
|
46
72
|
return {k: self.exchange_value(v) for k, v in inputs.items()}
|