mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,86 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_registry import api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+
+
+class ModuleDumper:
+    def __init__(self, service):
+        self.service = service
+        self.hook_handle_list = []
+
+    def start_module_dump(self, module, dump_name):
+        api_register.api_originality()
+        self.register_hook(module, dump_name)
+
+    def stop_module_dump(self):
+        api_register.api_modularity()
+        for hook_handle in self.hook_handle_list:
+            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
+                hook_handle.remove()
+        self.hook_handle_list.clear()
+
+    def register_hook(self, module, dump_name):
+        prefix_name = (
+            BaseScope.Module_Type_Module + Const.SEP +
+            dump_name + Const.SEP +
+            module.__class__.__name__ + Const.SEP
+        )
+        module_processor = self.service.module_processor
+        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
+            BaseScope.Module_Type_Module,
+            prefix_name
+        )
+
+        if module_processor.has_register_backward_hook(module):
+            logger.warning(
+                f"The {dump_name} module has registered deprecated register_backward_hook,"
+                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
+            )
+        if torch_version_above_or_equal_2:
+            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
+        else:
+            if not module_processor.has_register_backward_hook(module):
+                backward_hook_handle = module.register_full_backward_hook(
+                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
+                )
+                self.hook_handle_list.append(backward_hook_handle)
+            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
+        self.hook_handle_list.append(forward_hook_handle)
+        if not module_processor.has_register_backward_hook(module):
+            backward_hook_handle = module.register_full_backward_hook(backward_hook)
+            self.hook_handle_list.append(backward_hook_handle)
+
+        forward_pre_hook_handle = module.register_forward_pre_hook(
+            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
+        )
+        forward_hook_handle = module.register_forward_hook(
+            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
+        )
+        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
+        if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
+            backward_pre_hook_handle = module.register_full_backward_pre_hook(
+                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
+            )
+            backward_hook_handle = module.register_full_backward_hook(
+                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
+            )
+            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
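Judging by the 86 added lines, this new-file hunk appears to correspond to `msprobe/pytorch/dump/module_dump/module_dump.py` (+86 -0 in the listing above). The core of `ModuleDumper.register_hook` is version-gated hook registration with `RemovableHandle` bookkeeping. The sketch below reproduces that pattern with plain PyTorch only; `attach_dump_hooks` and the tag string are invented for illustration and are not part of msprobe.

```python
# Minimal sketch (not the msprobe implementation): version-gated hook registration
# with RemovableHandle bookkeeping, mirroring the pattern used by the new ModuleDumper.
import torch
import torch.nn as nn

torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'


def attach_dump_hooks(module: nn.Module, tag: str):
    """Register forward/backward hooks on `module` and return their handles."""
    handles = []

    def forward_hook(mod, args, output):
        print(f"[{tag}] forward of {mod.__class__.__name__}, output shape {tuple(output.shape)}")

    def backward_hook(mod, grad_input, grad_output):
        print(f"[{tag}] backward of {mod.__class__.__name__}")

    if torch_version_above_or_equal_2:
        # torch >= 2.0 provides kwargs-aware forward hooks and full backward pre-hooks
        handles.append(module.register_forward_hook(
            lambda mod, args, kwargs, output: forward_hook(mod, args, output),
            with_kwargs=True))
        handles.append(module.register_full_backward_pre_hook(
            lambda mod, grad_output: print(f"[{tag}] backward start of {mod.__class__.__name__}")))
    else:
        handles.append(module.register_forward_hook(forward_hook))
    handles.append(module.register_full_backward_hook(backward_hook))
    return handles


layer = nn.Linear(4, 2)
handles = attach_dump_hooks(layer, "Module.demo.Linear")
layer(torch.randn(3, 4)).sum().backward()
for h in handles:
    if isinstance(h, torch.utils.hooks.RemovableHandle):
        h.remove()  # detach cleanly, like ModuleDumper.stop_module_dump()
```

Keeping the handles is what lets `stop_module_dump` detach everything cleanly instead of leaving stale hooks on the module.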
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,12 +17,24 @@ from functools import wraps
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.data_dump.scope import ModuleRangeScope
+from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
+from msprobe.pytorch.common.log import logger
+from torch.utils.checkpoint import checkpoint as origin_checkpoint
+from torch.utils.checkpoint import set_checkpoint_early_stop
 from torch.utils.hooks import BackwardHook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
+def checkpoint_without_early_stop(*args, **kwargs):
+    with set_checkpoint_early_stop(False):
+        return origin_checkpoint(*args, **kwargs)
+
+
+def replace_checkpoint():
+    torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
@@ -30,13 +42,11 @@ class ModuleProcesser:
     module_node = {}
 
     def __init__(self, scope):
-        if isinstance(scope, ModuleRangeScope):
-            self.scope = scope
-        else:
-            self.scope = None
+        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
+        replace_checkpoint()
 
     @staticmethod
     def filter_tensor_and_tuple(func):
@@ -66,7 +76,7 @@ class ModuleProcesser:
             return ModuleProcesser.clone_if_tensor(result)
 
         return clone_return_value_func
-
+
     @staticmethod
     def clone_if_tensor(result):
         if isinstance(result, torch.Tensor):
@@ -88,6 +98,22 @@ class ModuleProcesser:
         ModuleProcesser.module_count[module_name] += 1
         return ModuleProcesser.module_count[module_name]
 
+    @staticmethod
+    def has_register_backward_hook(module):
+        return hasattr(module, '_backward_hooks') and \
+               len(module._backward_hooks) > 0 and \
+               module._is_full_backward_hook is False
+
+    @staticmethod
+    def get_modules_and_names(models):
+        modules_and_names_with_index = {}
+        if isinstance(models, (list, tuple)):
+            for index, model in enumerate(models):
+                modules_and_names_with_index[str(index)] = model.named_modules()
+        else:
+            modules_and_names_with_index["-1"] = models.named_modules()
+        return modules_and_names_with_index
+
     @classmethod
     def reset_module_stats(cls):
         cls.module_count = {}
@@ -95,6 +121,42 @@ class ModuleProcesser:
        cls.api_parent_node = ""
        cls.module_node = {}
 
+    def register_module_hook(self, models, build_hook):
+        logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
+        modules_and_names_with_index = self.get_modules_and_names(models)
+        for index, modules_and_names in modules_and_names_with_index.items():
+            model = models if index == "-1" else models[int(index)]
+            for name, module in modules_and_names:
+                if module == model:
+                    continue
+                module_index = (index + Const.SEP) if index != "-1" else ""
+                prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
+                               name + Const.SEP + module.__class__.__name__ + Const.SEP)
+                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
+                    BaseScope.Module_Type_Module,
+                    prefix_name
+                )
+
+                if self.has_register_backward_hook(module):
+                    logger.warning(
+                        f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
+                        f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
+                    )
+                if torch_version_above_or_equal_2:
+                    module.register_forward_hook(forward_hook, with_kwargs=True)
+                else:
+                    if not self.has_register_backward_hook(module):
+                        module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
+                    module.register_forward_hook(forward_hook_torch_version_below_2)
+                if not self.has_register_backward_hook(module):
+                    module.register_full_backward_hook(backward_hook)
+
+                module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
+                module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
+                if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
+                    module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
+                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
+
     def node_hook(self, name_prefix, start_or_stop, **kwargs):
 
         def pre_hook(module, input, output=None):
@@ -103,7 +165,10 @@ class ModuleProcesser:
            except IndexError as e:
                index = None
                pass
-
+            full_name = name_prefix + Const.SEP + str(index)
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                module.mindstudio_reserved_name = []
+            module.mindstudio_reserved_name.append(full_name)
            if self.module_stack:
                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
            else:
@@ -122,8 +187,11 @@ class ModuleProcesser:
                ModuleProcesser.api_parent_node = self.module_stack[-1]
            else:
                ModuleProcesser.api_parent_node = None
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                raise RuntimeError(f"module reserve name is None when pop")
+            current_name = module.mindstudio_reserved_name.pop()
            if self.scope:
-                self.scope.end_module(
+                self.scope.end_module(current_name)
 
        def backward_hook(module, input, output=None):
            try:
@@ -131,7 +199,10 @@ class ModuleProcesser:
            except IndexError as e:
                index = None
                pass
-
+            full_name = name_prefix + Const.SEP + str(index)
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                module.mindstudio_reserved_name = []
+            module.mindstudio_reserved_name.append(full_name)
            forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
            ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
                Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
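This group of hunks matches the relocated `msprobe/pytorch/dump/module_dump/module_processer.py` (+81 -10 in the listing). The most self-contained piece is the activation-checkpoint patch: with early stopping enabled, non-reentrant checkpointing can stop recomputation as soon as the saved activations it needs are ready, which presumably skips hooks the dumper relies on during backward. A minimal standalone sketch of the same idea, assuming a torch release that ships `torch.utils.checkpoint.set_checkpoint_early_stop` (the diff's own import suggests it does):

```python
# Standalone sketch of the checkpoint wrapper added above: disable early stopping so
# recomputation in checkpointed regions runs to completion and hooks keep firing.
# This is the same idea in isolation, not the exact msprobe code path.
import torch
from torch.utils.checkpoint import checkpoint as origin_checkpoint
from torch.utils.checkpoint import set_checkpoint_early_stop


def checkpoint_without_early_stop(*args, **kwargs):
    with set_checkpoint_early_stop(False):
        return origin_checkpoint(*args, **kwargs)


def replace_checkpoint():
    # Monkey-patch the public entry point so user code picks up the wrapper transparently.
    torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop


if __name__ == "__main__":
    replace_checkpoint()
    layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)
```

The constructor change (`replace_checkpoint()`) simply installs this wrapper once when a `ModuleProcesser` is created.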
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Dict
 
 import numpy as np
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from collections import defaultdict
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.core.common.const import Const
 
 
@@ -17,6 +17,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
+from msprobe.core.common.exceptions import FreeBenchmarkException
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import (
     DeviceType,
@@ -38,7 +39,6 @@ class DataParams:
     origin_func: Optional[Callable] = None
     api_type: Optional[str] = None
     fuzz_stage: Optional[str] = None
-    grad_unequal_flag: Optional[bool] = True
 
 
 @dataclass
@@ -126,9 +126,17 @@ def make_unequal_row(
     )
     if isinstance(ratio, float):
         row.max_rel = ratio - 1
+    if isinstance(ratio, str):
+        row.max_rel = ratio
     origin_tensor = data_params.original_result
     perturbed_tensor = data_params.perturbed_result
-    if index:
+    if index is not None:
+        if index >= len(origin_tensor) or index >= len(perturbed_tensor):
+            err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
+            raise FreeBenchmarkException(
+                FreeBenchmarkException.OutputIndexError,
+                error_info=err_msg,
+            )
         origin_tensor = origin_tensor[index]
         perturbed_tensor = perturbed_tensor[index]
         row.output_index = index
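One change in this hunk is easy to miss but meaningful: `if index:` became `if index is not None:`. The short example below shows why the truthiness check was a bug for multi-output APIs.

```python
# Why `if index:` was replaced by `if index is not None:` in make_unequal_row:
# output index 0 is a perfectly valid position but is falsy, so the old check
# silently skipped indexing the first output.
for index in (None, 0, 1):
    print(f"index={index!r:>4}  old check: {bool(index)!s:5}  new check: {index is not None}")
```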
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 import torch
+from msprobe.core.common.exceptions import FreeBenchmarkException
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark.common.enums import DeviceType
 
 
@@ -51,6 +54,7 @@ class Tools:
         return api_name.rsplit(".", 2)[0]
 
     @staticmethod
+    @recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
     def convert_device_and_dtype(
         tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
     ):
@@ -73,23 +77,41 @@ class Tools:
         return tensor_seq
 
     @staticmethod
+    @recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
     def convert_fuzz_output_to_origin(origin, perturbed):
-        if isinstance(origin, torch.Tensor):
+        if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
             origin.data = perturbed.to(origin.dtype).to(origin.device)
             return origin
-        if isinstance(origin, dict):
+        if isinstance(origin, dict) and isinstance(perturbed, dict):
             output = dict()
             for key, value in origin.items():
+                if key not in perturbed:
+                    err_msg = f"'{key}' not in perturbed output."
+                    raise FreeBenchmarkException(
+                        FreeBenchmarkException.InvalidPerturbedOutput,
+                        error_info=err_msg,
+                    )
                 output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
             return output
-        if isinstance(origin, (tuple, list)):
+        if isinstance(origin, (tuple, list)) and isinstance(perturbed, (tuple, list)):
             result = list()
+            if len(perturbed) != len(origin):
+                err_msg = (
+                    f"length of perturbed output ({len(perturbed)}) is different "
+                    f"from the length of original output ({len(origin)})."
+                )
+                raise FreeBenchmarkException(
+                    FreeBenchmarkException.InvalidPerturbedOutput, error_info=err_msg
+                )
             for index_, value in enumerate(origin):
                 result.append(
                     Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
                 )
             return type(origin)(result)
-
+        err_msg = f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported."
+        raise FreeBenchmarkException(
+            FreeBenchmarkException.UnsupportedType, error_info=err_msg
+        )
 
 
 class TorchC:
@@ -102,6 +124,7 @@ class TorchC:
     abs = torch._C._VariableFunctionsClass.abs
     where = torch._C._VariableFunctionsClass.where
     div = torch._C._VariableFunctionsClass.div
+    mul = torch._C._VariableFunctionsClass.mul
     max = torch._C._VariableFunctionsClass.max
     min = torch._C._VariableFunctionsClass.min
     gt = torch._C._VariableFunctionsClass.gt
@@ -116,3 +139,5 @@ class TorchC:
     tensor_split = torch._C._VariableFunctionsClass.tensor_split
     stack = torch._C._VariableFunctionsClass.stack
     reshape = torch._C._VariableFunctionsClass.reshape
+    nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
+    aminmax = torch._C._VariableFunctionsClass.aminmax
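A recurring change across these free-benchmark hunks is wrapping recursive helpers (`Tools.convert_device_and_dtype`, `Tools.convert_fuzz_output_to_origin`, and the perturbation layers further down) with `recursion_depth_decorator` from `msprobe.core.common.utils`. Its implementation is not part of this diff; the sketch below only illustrates what a depth-limiting decorator of that shape does, with an assumed limit and error type rather than msprobe's actual ones.

```python
# Illustrative sketch only: a recursion-depth guard of the kind this diff applies to
# recursive conversion helpers. The real recursion_depth_decorator lives in
# msprobe.core.common.utils; the limit and error type here are assumptions.
import functools

MAX_RECURSION_DEPTH = 100  # assumed limit for the sketch


def recursion_depth_guard(label, max_depth=MAX_RECURSION_DEPTH):
    def decorator(func):
        depth = {"value": 0}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            depth["value"] += 1
            try:
                if depth["value"] > max_depth:
                    raise RecursionError(f"{label}: recursion depth exceeded {max_depth}")
                return func(*args, **kwargs)
            finally:
                depth["value"] -= 1
        return wrapper
    return decorator


@recursion_depth_guard("demo: flatten")
def flatten(seq):
    # Recursively flatten nested lists/tuples; the guard bounds pathological nesting.
    out = []
    for item in seq:
        if isinstance(item, (list, tuple)):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out


print(flatten([1, [2, [3, 4]], (5,)]))  # [1, 2, 3, 4, 5]
```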
@@ -82,13 +82,11 @@ class GradSaver:
         data_params = DataParams()
         data_params.original_result = origin_grad
         data_params.perturbed_result = perturbed_grad
-        data_params.grad_unequal_flag = False
         data_params.valid_input_index = index
         try:
             handler.handle(data_params)
             if not data_params.is_consistent:
                 self.is_compare = False
-                data_params.grad_unequal_flag = True
                 data_params.is_consistent = True
                 data_params.perturbed_result = self.perturbed_grad_input
                 data_params.original_result = self.origin_grad_input
@@ -102,8 +100,13 @@ class GradSaver:
     def check_grad_input(self, origin_grad, new_grad_index):
         if self.perturbed_grad_input is None:
             raise FreeBenchmarkException(
-                FreeBenchmarkException.
-                f"grad not exists
+                FreeBenchmarkException.InvalidPerturbedOutput,
+                f"perturbed grad not exists for {self.api_name}.",
+            )
+        if len(self.perturbed_grad_input) <= new_grad_index:
+            raise FreeBenchmarkException(
+                FreeBenchmarkException.InvalidPerturbedOutput,
+                f"perturbed grad index {new_grad_index} is out of bounds for {self.api_name}.",
             )
         with torch.no_grad():
             perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
@@ -111,7 +114,7 @@ class GradSaver:
             )
             if origin_grad.shape != perturbed_grad.shape:
                 raise FreeBenchmarkException(
-                    FreeBenchmarkException.
+                    FreeBenchmarkException.InvalidPerturbedOutput,
                     f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
                     f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
                 )
@@ -164,6 +167,18 @@ class GradSaver:
            index_ = 0
            for object_ in inner_args:
                if object_ is CommonField.HOLD_PLACE:
+                    if index_ >= len(inputs):
+                        err_msg = (
+                            f"[msprobe] Free benchmark: When getting input from vjp, "
+                            f" the input index ({index_}) is out of bounds ({len(inputs)})."
+                        )
+                        logger.error_log_with_exp(
+                            err_msg,
+                            FreeBenchmarkException(
+                                FreeBenchmarkException.InvalidGrad,
+                                error_info=err_msg,
+                            ),
+                        )
                    _real_input.append(inputs[index_])
                    index_ += 1
                else:
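The last hunk above adds a bounds check before indexing `inputs` while rebuilding the real argument list around placeholder markers. A stripped-down illustration of that pattern follows; `HOLD_PLACE` stands in for `CommonField.HOLD_PLACE`, and the function name is invented for the example.

```python
# Illustrative sketch of the bounds check added above: when rebuilding the real input
# list around placeholder markers, an index past the end of `inputs` now fails loudly
# instead of surfacing later as an opaque IndexError deep inside the vjp path.
HOLD_PLACE = object()


def rebuild_real_input(inner_args, inputs):
    real_input, index_ = [], 0
    for object_ in inner_args:
        if object_ is HOLD_PLACE:
            if index_ >= len(inputs):
                raise IndexError(f"input index ({index_}) is out of bounds ({len(inputs)})")
            real_input.append(inputs[index_])
            index_ += 1
        else:
            real_input.append(object_)
    return real_input


print(rebuild_real_input([HOLD_PLACE, "kwarg", HOLD_PLACE], ["x", "y"]))  # ['x', 'kwarg', 'y']
```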
@@ -16,6 +16,7 @@
 import math
 
 import torch
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.utils import TorchC
@@ -67,6 +68,7 @@ class SingleCompare:
             return False
         return True
 
+    @recursion_depth_decorator("FreeBenchmark: SingleCompare.compare_seq")
     def compare_seq(self, actual, golden):
         if isinstance(golden, torch.Tensor):
             return self.compare_tensor_seq(actual, golden)
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import torch
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -26,6 +27,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
 
 class AddNoiseLayer(NpuBaseLayer):
 
+    @recursion_depth_decorator("FreeBenchmark: AddNoiseLayer.add_noise")
     def add_noise(self, tensor_obj):
         if isinstance(tensor_obj, torch.Tensor):
             self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
@@ -99,7 +101,7 @@ class AddNoiseLayer(NpuBaseLayer):
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the
+                f"Maximun value is less than the minimun threshold. Cancel add noise."
             )
             return False
         return True
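For context on what `AddNoiseLayer` guards against, here is a minimal, torch-only sketch of threshold-guarded noise injection: if the tensor's largest magnitude is already below a tolerance, adding noise would swamp the signal, so the perturbation is cancelled, which is what the completed warning message above reports. The epsilon and tolerance tables are assumptions for the sketch, not msprobe's `ThresholdConfig` values.

```python
# Torch-only sketch of threshold-guarded noise injection, the pattern AddNoiseLayer
# implements. The eps/tolerance tables below are assumptions, not msprobe's values.
import torch

NOISE_EPS = {torch.float16: 1e-4, torch.bfloat16: 1e-4, torch.float32: 1e-8}
ABS_TOL = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-6}


def add_small_noise(tensor: torch.Tensor) -> torch.Tensor:
    eps = NOISE_EPS.get(tensor.dtype)
    abs_tol = ABS_TOL.get(tensor.dtype)
    if eps is None or abs_tol is None:
        return tensor  # unsupported dtype: cancel perturbation
    if tensor.abs().max().item() < abs_tol:
        return tensor  # maximum value below the minimum threshold: cancel adding noise
    return tensor + eps * torch.randn_like(tensor)


x = torch.randn(4)
print((add_small_noise(x) - x).abs().max())  # noise on the order of 1e-8
```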
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import torch
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -31,6 +32,7 @@ class BitNoiseLayer(NpuBaseLayer):
         self.bit_tail: int = 1
         self.bit_type = None
 
+    @recursion_depth_decorator("FreeBenchmark: BitNoiseLayer.add_bit_noise")
     def add_bit_noise(self, tensor_obj):
         """
         对输入添加噪声
@@ -79,14 +81,14 @@ class BitNoiseLayer(NpuBaseLayer):
         判断是否需要添加扰动, bit翻转
         """
         if not self.bit_type:
-            logger.
+            logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
                 f"dtype unsupported. Cancel perturbation."
             )
             return False
         if tensor_obj.numel() == 0:
             logger.warning_on_rank_0(
-                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
+                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
                 f" Cancel adding noise."
             )
             return False
@@ -102,9 +104,9 @@ class BitNoiseLayer(NpuBaseLayer):
         )
         max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
-            logger.
+            logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the
+                f"Maximun value is less than the minimun threshold. Cancel add noise."
             )
             return False
         return True
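`BitNoiseLayer`'s docstrings describe the perturbation as bit flipping (roughly "add noise to the input" / "decide whether to perturb: bit flip"). A self-contained sketch of the core trick, flipping the least-significant bit of a float16 tensor by reinterpreting its storage as int16; the guard conditions mirror the warnings completed in the hunk, and the function is an illustration, not msprobe's code.

```python
# Illustrative sketch of low-bit "bit noise": flip the least-significant mantissa bit
# of a float16 tensor by viewing the same bytes as int16 and XOR-ing with 1.
import torch


def flip_lowest_bit(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.dtype != torch.float16:
        return tensor  # dtype unsupported: cancel perturbation
    if tensor.numel() == 0:
        return tensor  # empty tensor: cancel adding noise
    bits = tensor.view(torch.int16)                     # reinterpret storage, no copy
    flipped = torch.bitwise_xor(bits, torch.ones_like(bits))
    return flipped.view(torch.float16)


x = torch.tensor([1.0, -2.5, 0.0], dtype=torch.float16)
print(flip_lowest_bit(x) - x)  # tiny ULP-level differences
```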
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import torch
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
 from msprobe.pytorch.free_benchmark.common.params import DataParams
@@ -29,6 +30,7 @@ class ChangeValueLayer(NpuBaseLayer):
         self.head: int = 0
         self.tail: int = -1
 
+    @recursion_depth_decorator("FreeBenchmark: ChangeValueLayer.change_value")
     def change_value(self, tensor_obj):
         """
         交换张量首尾
@@ -15,6 +15,7 @@
 
 import torch
 from msprobe.core.common.const import Const
+from msprobe.core.common.utils import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import CommonField
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -26,6 +27,9 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
 
 class ImprovePrecisionLayer(NpuBaseLayer):
 
+    @recursion_depth_decorator(
+        "FreeBenchmark: ImprovePrecisionLayer.improve_tensor_precision"
+    )
     def improve_tensor_precision(self, tensor_obj):
         if (
             isinstance(tensor_obj, torch.Tensor)