mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,8 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
3
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
4
19
|
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.params import
|
|
20
|
+
from msprobe.pytorch.free_benchmark.common.params import (
|
|
21
|
+
DataParams,
|
|
22
|
+
HandlerParams,
|
|
23
|
+
data_pre_deal,
|
|
24
|
+
)
|
|
6
25
|
from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
|
|
7
26
|
from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
|
|
8
27
|
FuzzHandlerFactory,
|
|
@@ -83,8 +102,13 @@ class GradSaver:
|
|
|
83
102
|
def check_grad_input(self, origin_grad, new_grad_index):
|
|
84
103
|
if self.perturbed_grad_input is None:
|
|
85
104
|
raise FreeBenchmarkException(
|
|
86
|
-
FreeBenchmarkException.
|
|
87
|
-
f"grad not exists
|
|
105
|
+
FreeBenchmarkException.InvalidPerturbedOutput,
|
|
106
|
+
f"perturbed grad not exists for {self.api_name}.",
|
|
107
|
+
)
|
|
108
|
+
if len(self.perturbed_grad_input) <= new_grad_index:
|
|
109
|
+
raise FreeBenchmarkException(
|
|
110
|
+
FreeBenchmarkException.InvalidPerturbedOutput,
|
|
111
|
+
f"perturbed grad index {new_grad_index} is out of bounds for {self.api_name}.",
|
|
88
112
|
)
|
|
89
113
|
with torch.no_grad():
|
|
90
114
|
perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
|
|
@@ -92,9 +116,9 @@ class GradSaver:
|
|
|
92
116
|
)
|
|
93
117
|
if origin_grad.shape != perturbed_grad.shape:
|
|
94
118
|
raise FreeBenchmarkException(
|
|
95
|
-
FreeBenchmarkException.
|
|
119
|
+
FreeBenchmarkException.InvalidPerturbedOutput,
|
|
96
120
|
f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
|
|
97
|
-
f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
|
|
121
|
+
f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
|
|
98
122
|
)
|
|
99
123
|
return perturbed_grad
|
|
100
124
|
|
|
@@ -145,13 +169,25 @@ class GradSaver:
|
|
|
145
169
|
index_ = 0
|
|
146
170
|
for object_ in inner_args:
|
|
147
171
|
if object_ is CommonField.HOLD_PLACE:
|
|
172
|
+
if index_ >= len(inputs):
|
|
173
|
+
err_msg = (
|
|
174
|
+
f"[msprobe] Free benchmark: When getting input from vjp, "
|
|
175
|
+
f" the input index ({index_}) is out of bounds ({len(inputs)})."
|
|
176
|
+
)
|
|
177
|
+
logger.error_log_with_exp(
|
|
178
|
+
err_msg,
|
|
179
|
+
FreeBenchmarkException(
|
|
180
|
+
FreeBenchmarkException.InvalidGrad,
|
|
181
|
+
error_info=err_msg,
|
|
182
|
+
),
|
|
183
|
+
)
|
|
148
184
|
_real_input.append(inputs[index_])
|
|
149
185
|
index_ += 1
|
|
150
186
|
else:
|
|
151
187
|
_real_input.append(object_)
|
|
152
188
|
kwargs = self.kwargs.copy()
|
|
153
|
-
if
|
|
154
|
-
kwargs[
|
|
189
|
+
if "inplace" in kwargs:
|
|
190
|
+
kwargs["inplace"] = False
|
|
155
191
|
return self.origin_func(*_real_input, **kwargs)
|
|
156
192
|
|
|
157
193
|
_, grad_input = torch.autograd.functional.vjp(
|
|
@@ -159,12 +195,14 @@ class GradSaver:
|
|
|
159
195
|
)
|
|
160
196
|
return grad_input
|
|
161
197
|
|
|
162
|
-
def calculate_perturbed_grad_input(
|
|
198
|
+
def calculate_perturbed_grad_input(
|
|
199
|
+
self, grad_output, need_grad_tensors, inner_args
|
|
200
|
+
):
|
|
163
201
|
data_params = data_pre_deal(
|
|
164
202
|
self.handler_params.api_name,
|
|
165
203
|
self.get_grad_input_from_vjp,
|
|
166
204
|
[need_grad_tensors, grad_output, inner_args],
|
|
167
|
-
{}
|
|
205
|
+
{},
|
|
168
206
|
)
|
|
169
207
|
layer = LayerFactory.create(
|
|
170
208
|
self.handler_params.api_name,
|
|
@@ -1,6 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
|
|
3
18
|
import torch
|
|
19
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
4
20
|
from msprobe.pytorch.free_benchmark import logger
|
|
5
21
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
6
22
|
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
@@ -52,6 +68,7 @@ class SingleCompare:
|
|
|
52
68
|
return False
|
|
53
69
|
return True
|
|
54
70
|
|
|
71
|
+
@recursion_depth_decorator("FreeBenchmark: SingleCompare.compare_seq")
|
|
55
72
|
def compare_seq(self, actual, golden):
|
|
56
73
|
if isinstance(golden, torch.Tensor):
|
|
57
74
|
return self.compare_tensor_seq(actual, golden)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC
|
|
2
17
|
|
|
3
18
|
import torch
|
|
@@ -36,9 +51,9 @@ class FreeBenchmarkCheck(ABC):
|
|
|
36
51
|
|
|
37
52
|
def update_iter(self, update_iter):
|
|
38
53
|
self.current_iter = update_iter
|
|
39
|
-
|
|
54
|
+
|
|
40
55
|
def if_fix(self):
|
|
41
|
-
if self.config.handler_type==HandlerType.FIX:
|
|
56
|
+
if self.config.handler_type == HandlerType.FIX:
|
|
42
57
|
return True
|
|
43
58
|
return False
|
|
44
59
|
|
|
@@ -73,9 +88,9 @@ class FreeBenchmarkCheck(ABC):
|
|
|
73
88
|
layer.handle(data_params)
|
|
74
89
|
handler_params = make_handler_params(name, self.config, self.current_iter)
|
|
75
90
|
handler = FuzzHandlerFactory.create(handler_params)
|
|
76
|
-
perturbed_output = handler.handle(data_params)
|
|
91
|
+
perturbed_output = handler.handle(data_params)
|
|
77
92
|
return perturbed_output, handler.get_unequal_rows()
|
|
78
|
-
|
|
93
|
+
|
|
79
94
|
def backward(self, name, module, grad_output):
|
|
80
95
|
|
|
81
96
|
if not self.config.fuzz_stage == Const.BACKWARD:
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
17
|
from typing import Any
|
|
3
18
|
|
|
@@ -1,14 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkException
|
|
2
17
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
|
|
3
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
|
|
4
|
-
ImprovePrecisionLayer,
|
|
5
|
-
)
|
|
6
18
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
|
|
7
19
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
|
|
8
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
|
|
9
20
|
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
|
|
10
21
|
ChangeValueLayer,
|
|
11
22
|
)
|
|
23
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
|
|
24
|
+
ImprovePrecisionLayer,
|
|
25
|
+
)
|
|
26
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
|
|
12
27
|
from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
|
|
13
28
|
|
|
14
29
|
|
|
@@ -1,4 +1,20 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
17
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
2
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
19
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
20
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -11,6 +27,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
|
|
|
11
27
|
|
|
12
28
|
class AddNoiseLayer(NpuBaseLayer):
|
|
13
29
|
|
|
30
|
+
@recursion_depth_decorator("FreeBenchmark: AddNoiseLayer.add_noise")
|
|
14
31
|
def add_noise(self, tensor_obj):
|
|
15
32
|
if isinstance(tensor_obj, torch.Tensor):
|
|
16
33
|
self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
|
|
@@ -84,7 +101,7 @@ class AddNoiseLayer(NpuBaseLayer):
|
|
|
84
101
|
if max_val < abs_tol:
|
|
85
102
|
logger.warning_on_rank_0(
|
|
86
103
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
87
|
-
f"Maximun value is less than the
|
|
104
|
+
f"Maximun value is less than the minimun threshold. Cancel add noise."
|
|
88
105
|
)
|
|
89
106
|
return False
|
|
90
107
|
return True
|
|
@@ -1,4 +1,20 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
17
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
2
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
19
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
4
20
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -16,6 +32,7 @@ class BitNoiseLayer(NpuBaseLayer):
|
|
|
16
32
|
self.bit_tail: int = 1
|
|
17
33
|
self.bit_type = None
|
|
18
34
|
|
|
35
|
+
@recursion_depth_decorator("FreeBenchmark: BitNoiseLayer.add_bit_noise")
|
|
19
36
|
def add_bit_noise(self, tensor_obj):
|
|
20
37
|
"""
|
|
21
38
|
对输入添加噪声
|
|
@@ -64,14 +81,14 @@ class BitNoiseLayer(NpuBaseLayer):
|
|
|
64
81
|
判断是否需要添加扰动, bit翻转
|
|
65
82
|
"""
|
|
66
83
|
if not self.bit_type:
|
|
67
|
-
logger.
|
|
84
|
+
logger.warning_on_rank_0(
|
|
68
85
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
69
86
|
f"dtype unsupported. Cancel perturbation."
|
|
70
87
|
)
|
|
71
88
|
return False
|
|
72
89
|
if tensor_obj.numel() == 0:
|
|
73
90
|
logger.warning_on_rank_0(
|
|
74
|
-
f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
|
|
91
|
+
f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
|
|
75
92
|
f" Cancel adding noise."
|
|
76
93
|
)
|
|
77
94
|
return False
|
|
@@ -87,9 +104,9 @@ class BitNoiseLayer(NpuBaseLayer):
|
|
|
87
104
|
)
|
|
88
105
|
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
89
106
|
if max_val < abs_tol:
|
|
90
|
-
logger.
|
|
107
|
+
logger.warning_on_rank_0(
|
|
91
108
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
92
|
-
f"Maximun value is less than the
|
|
109
|
+
f"Maximun value is less than the minimun threshold. Cancel add noise."
|
|
93
110
|
)
|
|
94
111
|
return False
|
|
95
112
|
return True
|
|
@@ -1,4 +1,20 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
17
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
2
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
19
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
4
20
|
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
@@ -14,6 +30,7 @@ class ChangeValueLayer(NpuBaseLayer):
|
|
|
14
30
|
self.head: int = 0
|
|
15
31
|
self.tail: int = -1
|
|
16
32
|
|
|
33
|
+
@recursion_depth_decorator("FreeBenchmark: ChangeValueLayer.change_value")
|
|
17
34
|
def change_value(self, tensor_obj):
|
|
18
35
|
"""
|
|
19
36
|
交换张量首尾
|
|
@@ -54,10 +71,19 @@ class ChangeValueLayer(NpuBaseLayer):
|
|
|
54
71
|
"""
|
|
55
72
|
判断是否需要添加扰动, 首尾值交换
|
|
56
73
|
"""
|
|
57
|
-
|
|
74
|
+
# 对于维度大于1的张量、要求1维至少大于1且0维和1维至少一个长度大于2
|
|
75
|
+
if tensor_obj.ndim > 1:
|
|
76
|
+
if tensor_obj.size(1) == 0 or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2):
|
|
77
|
+
logger.info_on_rank_0(
|
|
78
|
+
f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
|
|
79
|
+
f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
|
|
80
|
+
)
|
|
81
|
+
return False
|
|
82
|
+
# 不支持维度等于0的张量、对于维度等于1的张量、要求0维长度大于2
|
|
83
|
+
elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
|
|
58
84
|
logger.info_on_rank_0(
|
|
59
85
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
60
|
-
f"
|
|
86
|
+
f"0-dimension must greater than 1. Cancel change value."
|
|
61
87
|
)
|
|
62
88
|
return False
|
|
63
89
|
return True
|
|
@@ -1,5 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
3
19
|
from msprobe.pytorch.free_benchmark import logger
|
|
4
20
|
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
5
21
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -11,6 +27,9 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
|
|
|
11
27
|
|
|
12
28
|
class ImprovePrecisionLayer(NpuBaseLayer):
|
|
13
29
|
|
|
30
|
+
@recursion_depth_decorator(
|
|
31
|
+
"FreeBenchmark: ImprovePrecisionLayer.improve_tensor_precision"
|
|
32
|
+
)
|
|
14
33
|
def improve_tensor_precision(self, tensor_obj):
|
|
15
34
|
if (
|
|
16
35
|
isinstance(tensor_obj, torch.Tensor)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import abstractmethod
|
|
2
17
|
from typing import Any
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
from msprobe.pytorch.free_benchmark import logger
|
|
3
18
|
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
@@ -1,10 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
from abc import ABC, abstractmethod
|
|
3
18
|
from typing import Any, Optional, Tuple
|
|
4
|
-
import numpy as np
|
|
5
19
|
|
|
20
|
+
import numpy as np
|
|
6
21
|
import torch
|
|
7
22
|
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
8
24
|
from msprobe.pytorch.free_benchmark import logger
|
|
9
25
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
10
26
|
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
@@ -35,7 +51,9 @@ class FuzzHandler(ABC):
|
|
|
35
51
|
origin_ouput = origin_ouput.values
|
|
36
52
|
perturbed_output = perturbed_output.values
|
|
37
53
|
if hasattr(perturbed_output, "dtype"):
|
|
38
|
-
abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
|
|
54
|
+
abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
|
|
55
|
+
perturbed_output.dtype, FuzzThreshold.F32_THD
|
|
56
|
+
)
|
|
39
57
|
else:
|
|
40
58
|
abs_tol = FuzzThreshold.F32_THD
|
|
41
59
|
return (
|
|
@@ -53,16 +71,22 @@ class FuzzHandler(ABC):
|
|
|
53
71
|
:return origin_output_chunks: 切块后原始输出列表
|
|
54
72
|
:return perturbed_output_chunks: 切块后扰动后输出列表
|
|
55
73
|
"""
|
|
56
|
-
single_output_mem =
|
|
74
|
+
single_output_mem = (
|
|
75
|
+
origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
|
|
76
|
+
)
|
|
57
77
|
if single_output_mem == 0 or origin_output.ndim == 0:
|
|
58
78
|
return [origin_output], [perturbed_output]
|
|
59
79
|
# 张量大小和批数之间的关系:chunks_exp=math.log(M,2)-4, chunks=2**chunks_exp (M为对比张量数据大小[Mb])
|
|
60
80
|
chunks_exp = int(math.log(single_output_mem, 2)) - 4
|
|
61
|
-
chunks = 2
|
|
81
|
+
chunks = 2**chunks_exp
|
|
62
82
|
chunks = max(chunks, 1)
|
|
63
83
|
chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
|
|
64
|
-
origin_output_chunks = TorchC.tensor_split(
|
|
65
|
-
|
|
84
|
+
origin_output_chunks = TorchC.tensor_split(
|
|
85
|
+
TorchC.reshape(origin_output, (-1,)), chunks
|
|
86
|
+
)
|
|
87
|
+
perturbed_output_chunks = TorchC.tensor_split(
|
|
88
|
+
TorchC.reshape(perturbed_output, (-1,)), chunks
|
|
89
|
+
)
|
|
66
90
|
return origin_output_chunks, perturbed_output_chunks
|
|
67
91
|
|
|
68
92
|
@staticmethod
|
|
@@ -80,14 +104,24 @@ class FuzzHandler(ABC):
|
|
|
80
104
|
pass
|
|
81
105
|
|
|
82
106
|
def get_ratio_from_specific_norm(
|
|
83
|
-
|
|
107
|
+
self, origin_output, perturbed_output, norm_type, abs_tol
|
|
84
108
|
):
|
|
85
109
|
if norm_type == NormType.ENDLESS_NORM:
|
|
86
110
|
return self.calculate_error(origin_output, perturbed_output, abs_tol)
|
|
87
111
|
return ThresholdConfig.COMP_CONSISTENT
|
|
88
112
|
|
|
89
113
|
def calculate_error(self, origin_output, perturbed_output, abs_tol):
|
|
90
|
-
origin_output_chunks, perturbed_output_chunks =
|
|
114
|
+
origin_output_chunks, perturbed_output_chunks = (
|
|
115
|
+
self.tensor_split_for_error_calculate(origin_output, perturbed_output)
|
|
116
|
+
)
|
|
117
|
+
if len(origin_output_chunks) != len(perturbed_output_chunks):
|
|
118
|
+
err_msg = (
|
|
119
|
+
f"For {self.params.api_name}, the number of compare tensor chunks is different: "
|
|
120
|
+
f"{len(origin_output_chunks)} != {len(perturbed_output_chunks)}. please check!"
|
|
121
|
+
)
|
|
122
|
+
raise FreeBenchmarkException(
|
|
123
|
+
FreeBenchmarkException.OutputIndexError, err_msg
|
|
124
|
+
)
|
|
91
125
|
norm1 = -np.inf
|
|
92
126
|
norm2 = -np.inf
|
|
93
127
|
norm3 = np.inf
|
|
@@ -95,11 +129,25 @@ class FuzzHandler(ABC):
|
|
|
95
129
|
if chunk_origin.nelement() == 0:
|
|
96
130
|
break
|
|
97
131
|
chunk_perturbed = perturbed_output_chunks[i]
|
|
98
|
-
ratio_tensor1 = TorchC.where(
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
132
|
+
ratio_tensor1 = TorchC.where(
|
|
133
|
+
TorchC.abs(chunk_perturbed) > abs_tol,
|
|
134
|
+
TorchC.div(
|
|
135
|
+
TorchC.clamp(chunk_origin, min=abs_tol),
|
|
136
|
+
TorchC.clamp(chunk_perturbed, min=abs_tol),
|
|
137
|
+
),
|
|
138
|
+
1,
|
|
139
|
+
)
|
|
140
|
+
ratio_tensor2 = TorchC.where(
|
|
141
|
+
TorchC.abs(chunk_origin) > abs_tol,
|
|
142
|
+
TorchC.div(
|
|
143
|
+
TorchC.clamp(chunk_perturbed, min=abs_tol),
|
|
144
|
+
TorchC.clamp(chunk_origin, min=abs_tol),
|
|
145
|
+
),
|
|
146
|
+
1,
|
|
147
|
+
)
|
|
148
|
+
norm_values = TorchC.stack(
|
|
149
|
+
[TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
|
|
150
|
+
)
|
|
103
151
|
max_ratio1, max_ratio2 = norm_values.tolist()
|
|
104
152
|
norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
|
|
105
153
|
norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
|
|
@@ -126,13 +174,13 @@ class FuzzHandler(ABC):
|
|
|
126
174
|
if self.params.fuzz_stage == Const.BACKWARD:
|
|
127
175
|
abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
|
|
128
176
|
else:
|
|
129
|
-
abs_tol = abs_tol
|
|
177
|
+
abs_tol = abs_tol**0.5
|
|
130
178
|
return self.get_ratio_from_specific_norm(
|
|
131
179
|
origin_output, perturbed_output, norm_type, abs_tol
|
|
132
180
|
)
|
|
133
181
|
|
|
134
182
|
def npu_compare(
|
|
135
|
-
|
|
183
|
+
self, origin_output, perturbed_output
|
|
136
184
|
) -> Tuple[bool, Optional[float]]:
|
|
137
185
|
|
|
138
186
|
if isinstance(perturbed_output, int):
|
|
@@ -150,6 +198,7 @@ class FuzzHandler(ABC):
|
|
|
150
198
|
f"[msprobe] Free Benchmark: For {self.params.api_name} "
|
|
151
199
|
f"The compare for output type {type(perturbed_output)} is not supported"
|
|
152
200
|
)
|
|
201
|
+
return True, 1
|
|
153
202
|
|
|
154
203
|
threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
|
|
155
204
|
ratio = self.ratio_calculate(
|
|
@@ -189,7 +238,7 @@ class FuzzHandler(ABC):
|
|
|
189
238
|
max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
|
|
190
239
|
)
|
|
191
240
|
data_params.is_consistent = (
|
|
192
|
-
|
|
241
|
+
is_consistent and data_params.is_consistent
|
|
193
242
|
)
|
|
194
243
|
if not is_consistent and data_params.grad_unequal_flag:
|
|
195
244
|
self.unequal_rows.append(
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
from msprobe.pytorch.free_benchmark import logger
|