mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,34 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import namedtuple
|
|
17
|
+
|
|
1
18
|
import torch
|
|
2
|
-
from
|
|
3
|
-
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
4
|
-
from msprobe.pytorch.service import Service
|
|
5
|
-
from msprobe.pytorch.common.log import logger
|
|
6
|
-
from msprobe.pytorch.pt_config import parse_json_config
|
|
19
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
7
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
8
|
-
from msprobe.core.common.
|
|
21
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
22
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
9
25
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
26
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
27
|
+
from msprobe.pytorch.service import Service
|
|
28
|
+
from torch.utils.data import dataloader
|
|
29
|
+
|
|
30
|
+
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
|
|
31
|
+
"dump_path", "level", "model"])
|
|
10
32
|
|
|
11
33
|
|
|
12
34
|
class PrecisionDebugger:
|
|
@@ -30,20 +52,26 @@ class PrecisionDebugger:
|
|
|
30
52
|
step=None,
|
|
31
53
|
):
|
|
32
54
|
if not hasattr(self, "initialized"):
|
|
55
|
+
config_params = ConfigParameters(config_path,
|
|
56
|
+
task,
|
|
57
|
+
dump_path,
|
|
58
|
+
level,
|
|
59
|
+
model)
|
|
60
|
+
self.check_input_params(config_params)
|
|
61
|
+
|
|
33
62
|
self.api_origin = False
|
|
34
63
|
self.initialized = True
|
|
35
|
-
self.model =
|
|
64
|
+
self.model = model
|
|
36
65
|
common_config, task_config = parse_json_config(config_path, task)
|
|
37
|
-
self.task = common_config.task
|
|
66
|
+
self.task = task if task else common_config.task
|
|
38
67
|
if self.task == Const.GRAD_PROBE:
|
|
39
68
|
self.gm = GradientMonitor(common_config, task_config)
|
|
40
69
|
return
|
|
41
70
|
if step:
|
|
42
|
-
common_config.step = step
|
|
71
|
+
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
43
72
|
self.config = DebuggerConfig(
|
|
44
73
|
common_config, task_config, task, dump_path, level
|
|
45
74
|
)
|
|
46
|
-
self.config.check_model(self.model)
|
|
47
75
|
self.service = Service(self.config)
|
|
48
76
|
self.enable_dataloader = self.config.enable_dataloader
|
|
49
77
|
if self.enable_dataloader:
|
|
@@ -55,20 +83,40 @@ class PrecisionDebugger:
|
|
|
55
83
|
return self._instance
|
|
56
84
|
|
|
57
85
|
@staticmethod
|
|
58
|
-
def
|
|
59
|
-
if
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
86
|
+
def check_input_params(args):
|
|
87
|
+
if args.config_path is not None:
|
|
88
|
+
if not isinstance(args.config_path, str):
|
|
89
|
+
raise MsprobeException(
|
|
90
|
+
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
91
|
+
file_checker = FileChecker(
|
|
92
|
+
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
93
|
+
file_checker.common_check()
|
|
94
|
+
|
|
95
|
+
if args.task is not None and args.task not in Const.TASK_LIST:
|
|
96
|
+
raise MsprobeException(
|
|
97
|
+
MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
|
|
98
|
+
|
|
99
|
+
if args.dump_path is not None:
|
|
100
|
+
if not isinstance(args.dump_path, str):
|
|
101
|
+
raise MsprobeException(
|
|
102
|
+
MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
|
|
103
|
+
|
|
104
|
+
if args.level is not None and args.level not in Const.LEVEL_LIST:
|
|
105
|
+
raise MsprobeException(
|
|
106
|
+
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
107
|
+
|
|
108
|
+
if args.model is not None and not isinstance(args.model, torch.nn.Module):
|
|
109
|
+
raise MsprobeException(
|
|
110
|
+
MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
|
|
64
111
|
|
|
65
112
|
@classmethod
|
|
66
|
-
def start(cls):
|
|
113
|
+
def start(cls, model=None):
|
|
67
114
|
instance = cls._instance
|
|
115
|
+
if not instance:
|
|
116
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
68
117
|
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
69
118
|
return
|
|
70
|
-
|
|
71
|
-
raise Exception("No instance of PrecisionDebugger found.")
|
|
119
|
+
instance.config.check_model(instance, model)
|
|
72
120
|
if instance.enable_dataloader:
|
|
73
121
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
74
122
|
else:
|
|
@@ -85,10 +133,10 @@ class PrecisionDebugger:
|
|
|
85
133
|
@classmethod
|
|
86
134
|
def stop(cls):
|
|
87
135
|
instance = cls._instance
|
|
136
|
+
if not instance:
|
|
137
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
88
138
|
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
89
139
|
return
|
|
90
|
-
if not instance:
|
|
91
|
-
raise Exception("PrecisionDebugger instance is not created.")
|
|
92
140
|
if instance.enable_dataloader:
|
|
93
141
|
logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
|
|
94
142
|
else:
|
|
@@ -96,16 +144,16 @@ class PrecisionDebugger:
|
|
|
96
144
|
|
|
97
145
|
@classmethod
|
|
98
146
|
def step(cls):
|
|
147
|
+
if not cls._instance:
|
|
148
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
99
149
|
if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
100
150
|
return
|
|
101
|
-
if not cls._instance:
|
|
102
|
-
raise Exception("PrecisionDebugger instance is not created.")
|
|
103
151
|
cls._instance.service.step()
|
|
104
152
|
|
|
105
153
|
@classmethod
|
|
106
154
|
def monitor(cls, model):
|
|
107
155
|
if not cls._instance:
|
|
108
|
-
raise Exception(
|
|
156
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
109
157
|
if cls._instance.task != Const.GRAD_PROBE:
|
|
110
158
|
return
|
|
111
159
|
cls._instance.gm.monitor(model)
|
|
@@ -1,8 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]
|
|
17
|
+
|
|
3
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
4
21
|
|
|
5
|
-
from .main import FreeBenchmarkCheck
|
|
6
22
|
from .common.params import UnequalRow
|
|
7
|
-
|
|
8
|
-
__all__ = [FreeBenchmarkCheck, UnequalRow]
|
|
23
|
+
from .main import FreeBenchmarkCheck
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
from msprobe.core.common.const import Const
|
|
2
|
+
|
|
3
|
+
|
|
1
4
|
class PerturbationMode:
|
|
2
5
|
ADD_NOISE = "add_noise"
|
|
3
6
|
CHANGE_VALUE = "change_value"
|
|
@@ -35,3 +38,28 @@ class FuzzLevel:
|
|
|
35
38
|
BASE_LEVEL = "L1"
|
|
36
39
|
ADV_LEVEL = "L2"
|
|
37
40
|
REAL_LEVEL = "L3"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PytorchFreeBenchmarkConst:
|
|
44
|
+
PERTURBATION_MODE_LIST = [
|
|
45
|
+
PerturbationMode.ADD_NOISE,
|
|
46
|
+
PerturbationMode.CHANGE_VALUE,
|
|
47
|
+
PerturbationMode.IMPROVE_PRECISION,
|
|
48
|
+
PerturbationMode.NO_CHANGE,
|
|
49
|
+
PerturbationMode.BIT_NOISE,
|
|
50
|
+
PerturbationMode.TO_CPU,
|
|
51
|
+
]
|
|
52
|
+
DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION
|
|
53
|
+
DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU]
|
|
54
|
+
DEFAULT_DEVICE = DeviceType.NPU
|
|
55
|
+
HANDLER_LIST = [HandlerType.CHECK, HandlerType.FIX]
|
|
56
|
+
DEFAULT_HANDLER = HandlerType.CHECK
|
|
57
|
+
FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL]
|
|
58
|
+
DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL
|
|
59
|
+
FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD]
|
|
60
|
+
FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU]
|
|
61
|
+
DEFAULT_FUZZ_STAGE = Const.FORWARD
|
|
62
|
+
DEFAULT_PREHEAT_STEP = 15
|
|
63
|
+
DEFAULT_MAX_SAMPLE = 20
|
|
64
|
+
CPU_MODE_LIST = [PerturbationMode.TO_CPU]
|
|
65
|
+
FIX_STAGE_LIST = [Const.FORWARD]
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from dataclasses import dataclass
|
|
2
17
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
3
18
|
|
|
@@ -75,7 +90,8 @@ class Tools:
|
|
|
75
90
|
)
|
|
76
91
|
return type(origin)(result)
|
|
77
92
|
return origin
|
|
78
|
-
|
|
93
|
+
|
|
94
|
+
|
|
79
95
|
class TorchC:
|
|
80
96
|
sum = torch._C._VariableFunctionsClass.sum
|
|
81
97
|
isinf = torch._C._VariableFunctionsClass.isinf
|
|
@@ -1,8 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
3
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
4
19
|
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.params import
|
|
20
|
+
from msprobe.pytorch.free_benchmark.common.params import (
|
|
21
|
+
DataParams,
|
|
22
|
+
HandlerParams,
|
|
23
|
+
data_pre_deal,
|
|
24
|
+
)
|
|
6
25
|
from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
|
|
7
26
|
from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
|
|
8
27
|
FuzzHandlerFactory,
|
|
@@ -84,7 +103,7 @@ class GradSaver:
|
|
|
84
103
|
if self.perturbed_grad_input is None:
|
|
85
104
|
raise FreeBenchmarkException(
|
|
86
105
|
FreeBenchmarkException.InvalidGrad,
|
|
87
|
-
f"grad not exists : {self.api_name}."
|
|
106
|
+
f"grad not exists : {self.api_name}.",
|
|
88
107
|
)
|
|
89
108
|
with torch.no_grad():
|
|
90
109
|
perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
|
|
@@ -94,7 +113,7 @@ class GradSaver:
|
|
|
94
113
|
raise FreeBenchmarkException(
|
|
95
114
|
FreeBenchmarkException.InvalidGrad,
|
|
96
115
|
f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
|
|
97
|
-
f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
|
|
116
|
+
f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
|
|
98
117
|
)
|
|
99
118
|
return perturbed_grad
|
|
100
119
|
|
|
@@ -150,8 +169,8 @@ class GradSaver:
|
|
|
150
169
|
else:
|
|
151
170
|
_real_input.append(object_)
|
|
152
171
|
kwargs = self.kwargs.copy()
|
|
153
|
-
if
|
|
154
|
-
kwargs[
|
|
172
|
+
if "inplace" in kwargs:
|
|
173
|
+
kwargs["inplace"] = False
|
|
155
174
|
return self.origin_func(*_real_input, **kwargs)
|
|
156
175
|
|
|
157
176
|
_, grad_input = torch.autograd.functional.vjp(
|
|
@@ -159,12 +178,14 @@ class GradSaver:
|
|
|
159
178
|
)
|
|
160
179
|
return grad_input
|
|
161
180
|
|
|
162
|
-
def calculate_perturbed_grad_input(
|
|
181
|
+
def calculate_perturbed_grad_input(
|
|
182
|
+
self, grad_output, need_grad_tensors, inner_args
|
|
183
|
+
):
|
|
163
184
|
data_params = data_pre_deal(
|
|
164
185
|
self.handler_params.api_name,
|
|
165
186
|
self.get_grad_input_from_vjp,
|
|
166
187
|
[need_grad_tensors, grad_output, inner_args],
|
|
167
|
-
{}
|
|
188
|
+
{},
|
|
168
189
|
)
|
|
169
190
|
layer = LayerFactory.create(
|
|
170
191
|
self.handler_params.api_name,
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
|
|
3
18
|
import torch
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC
|
|
2
17
|
|
|
3
18
|
import torch
|
|
@@ -36,9 +51,9 @@ class FreeBenchmarkCheck(ABC):
|
|
|
36
51
|
|
|
37
52
|
def update_iter(self, update_iter):
|
|
38
53
|
self.current_iter = update_iter
|
|
39
|
-
|
|
54
|
+
|
|
40
55
|
def if_fix(self):
|
|
41
|
-
if self.config.handler_type==HandlerType.FIX:
|
|
56
|
+
if self.config.handler_type == HandlerType.FIX:
|
|
42
57
|
return True
|
|
43
58
|
return False
|
|
44
59
|
|
|
@@ -73,9 +88,9 @@ class FreeBenchmarkCheck(ABC):
|
|
|
73
88
|
layer.handle(data_params)
|
|
74
89
|
handler_params = make_handler_params(name, self.config, self.current_iter)
|
|
75
90
|
handler = FuzzHandlerFactory.create(handler_params)
|
|
76
|
-
perturbed_output = handler.handle(data_params)
|
|
91
|
+
perturbed_output = handler.handle(data_params)
|
|
77
92
|
return perturbed_output, handler.get_unequal_rows()
|
|
78
|
-
|
|
93
|
+
|
|
79
94
|
def backward(self, name, module, grad_output):
|
|
80
95
|
|
|
81
96
|
if not self.config.fuzz_stage == Const.BACKWARD:
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
17
|
from typing import Any
|
|
3
18
|
|
|
@@ -1,14 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkException
|
|
2
17
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
|
|
3
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
|
|
4
|
-
ImprovePrecisionLayer,
|
|
5
|
-
)
|
|
6
18
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
|
|
7
19
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
|
|
8
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
|
|
9
20
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
|
|
10
21
|
ChangeValueLayer,
|
|
11
22
|
)
|
|
23
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
|
|
24
|
+
ImprovePrecisionLayer,
|
|
25
|
+
)
|
|
26
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
|
|
12
27
|
from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
|
|
13
28
|
|
|
14
29
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -54,10 +69,19 @@ class ChangeValueLayer(NpuBaseLayer):
|
|
|
54
69
|
"""
|
|
55
70
|
判断是否需要添加扰动, 首尾值交换
|
|
56
71
|
"""
|
|
57
|
-
|
|
72
|
+
# 对于维度大于1的张量、要求1维至少大于1且0维和1维至少一个长度大于2
|
|
73
|
+
if tensor_obj.ndim > 1:
|
|
74
|
+
if tensor_obj.size(1) == 0 or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2):
|
|
75
|
+
logger.info_on_rank_0(
|
|
76
|
+
f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
|
|
77
|
+
f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
|
|
78
|
+
)
|
|
79
|
+
return False
|
|
80
|
+
# 不支持维度等于0的张量、对于维度等于1的张量、要求0维长度大于2
|
|
81
|
+
elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
|
|
58
82
|
logger.info_on_rank_0(
|
|
59
83
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
60
|
-
f"
|
|
84
|
+
f"0-dimension must greater than 1. Cancel change value."
|
|
61
85
|
)
|
|
62
86
|
return False
|
|
63
87
|
return True
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.core.common.const import Const
|
|
3
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import abstractmethod
|
|
2
17
|
from typing import Any
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.params import DataParams
|