mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,6 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
3
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
4
21
|
|
|
5
22
|
|
|
6
23
|
class DebuggerConfig:
|
|
@@ -10,30 +27,28 @@ class DebuggerConfig:
|
|
|
10
27
|
self.rank = common_config.rank if common_config.rank else []
|
|
11
28
|
self.step = common_config.step if common_config.step else []
|
|
12
29
|
self.level = level or common_config.level or "L1"
|
|
13
|
-
self.seed = common_config.seed if common_config.seed else 1234
|
|
14
|
-
self.is_deterministic = common_config.is_deterministic
|
|
15
30
|
self.enable_dataloader = common_config.enable_dataloader
|
|
16
31
|
self.scope = task_config.scope if task_config.scope else []
|
|
17
32
|
self.list = task_config.list if task_config.list else []
|
|
18
33
|
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
19
|
-
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
20
|
-
self.backward_input = {}
|
|
21
|
-
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
22
|
-
self.is_forward_acl_dump = True
|
|
23
34
|
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
24
35
|
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
25
36
|
self.framework = Const.PT_FRAMEWORK
|
|
26
37
|
|
|
38
|
+
if self.level == Const.LEVEL_L2:
|
|
39
|
+
self.is_backward_kernel_dump = False
|
|
40
|
+
self._check_and_adjust_config_with_l2()
|
|
41
|
+
|
|
27
42
|
if self.task == Const.FREE_BENCHMARK:
|
|
28
|
-
self.fuzz_device = task_config.fuzz_device
|
|
29
|
-
self.handler_type = task_config.handler_type
|
|
30
|
-
self.pert_mode = task_config.pert_mode
|
|
31
|
-
self.fuzz_level = task_config.fuzz_level
|
|
32
|
-
self.fuzz_stage = task_config.fuzz_stage
|
|
43
|
+
self.fuzz_device = task_config.fuzz_device
|
|
44
|
+
self.handler_type = task_config.handler_type
|
|
45
|
+
self.pert_mode = task_config.pert_mode
|
|
46
|
+
self.fuzz_level = task_config.fuzz_level
|
|
47
|
+
self.fuzz_stage = task_config.fuzz_stage
|
|
33
48
|
self.preheat_config = {
|
|
34
|
-
"if_preheat": task_config.if_preheat
|
|
35
|
-
"preheat_step": task_config.preheat_step
|
|
36
|
-
"max_sample": task_config.max_sample
|
|
49
|
+
"if_preheat": task_config.if_preheat,
|
|
50
|
+
"preheat_step": task_config.preheat_step,
|
|
51
|
+
"max_sample": task_config.max_sample
|
|
37
52
|
}
|
|
38
53
|
|
|
39
54
|
self.online_run_ut = False
|
|
@@ -44,52 +59,54 @@ class DebuggerConfig:
|
|
|
44
59
|
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
45
60
|
self.host = task_config.host if task_config.host else ""
|
|
46
61
|
self.port = task_config.port if task_config.port else -1
|
|
62
|
+
self.online_run_ut_recompute = task_config.online_run_ut_recompute \
|
|
63
|
+
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
47
64
|
|
|
48
65
|
self.check()
|
|
49
|
-
if self.step:
|
|
50
|
-
self.step.sort()
|
|
51
|
-
if self.level == "L2":
|
|
52
|
-
if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
|
|
53
|
-
raise ValueError("scope must be configured as a list with one api name")
|
|
54
|
-
if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
|
|
55
|
-
raise ValueError("backward_input must be configured when scope contains 'backward'")
|
|
56
|
-
if Const.BACKWARD in self.scope[0]:
|
|
57
|
-
self.is_forward_acl_dump = False
|
|
58
|
-
for index, scope_spec in enumerate(self.scope):
|
|
59
|
-
self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
|
|
60
|
-
self.backward_input[self.scope[index]] = self.backward_input_list[index]
|
|
61
|
-
seed_all(self.seed, self.is_deterministic)
|
|
62
66
|
|
|
63
67
|
def check_kwargs(self):
|
|
64
68
|
if self.task and self.task not in Const.TASK_LIST:
|
|
65
|
-
raise
|
|
69
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
70
|
+
f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
|
|
66
71
|
if self.level and self.level not in Const.LEVEL_LIST:
|
|
67
|
-
raise
|
|
72
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
73
|
+
f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
|
|
68
74
|
if not self.dump_path:
|
|
69
|
-
raise
|
|
75
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
76
|
+
f"The dump_path not found.")
|
|
70
77
|
|
|
71
78
|
def check(self):
|
|
72
79
|
self.check_kwargs()
|
|
73
|
-
self._check_rank()
|
|
74
|
-
self._check_step()
|
|
75
80
|
return True
|
|
76
81
|
|
|
77
|
-
def check_model(self,
|
|
78
|
-
if self.level in ["L0", "mix"]
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
82
|
+
def check_model(self, instance, start_model):
|
|
83
|
+
if self.level not in ["L0", "mix"]:
|
|
84
|
+
if instance.model is not None or start_model is not None:
|
|
85
|
+
logger.warning_on_rank_0(
|
|
86
|
+
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
87
|
+
return
|
|
88
|
+
if start_model is None:
|
|
89
|
+
if instance.model is None:
|
|
90
|
+
logger.error_on_rank_0(
|
|
91
|
+
f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
|
|
92
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
93
|
+
return
|
|
94
|
+
if isinstance(start_model, torch.nn.Module):
|
|
95
|
+
instance.model = start_model
|
|
96
|
+
else:
|
|
97
|
+
logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
|
|
98
|
+
raise MsprobeException(
|
|
99
|
+
MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
|
|
90
100
|
|
|
91
|
-
def
|
|
92
|
-
if self.
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
101
|
+
def _check_and_adjust_config_with_l2(self):
|
|
102
|
+
if self.scope:
|
|
103
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
104
|
+
f"When level is set to L2, the scope cannot be configured.")
|
|
105
|
+
if not self.list or len(self.list) != 1:
|
|
106
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
107
|
+
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
108
|
+
api_name = self.list[0]
|
|
109
|
+
if api_name.endswith(Const.BACKWARD):
|
|
110
|
+
self.is_backward_kernel_dump = True
|
|
111
|
+
api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
|
|
112
|
+
self.list.append(api_forward_name)
|
|
@@ -1,12 +1,34 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import namedtuple
|
|
17
|
+
|
|
1
18
|
import torch
|
|
2
|
-
from
|
|
3
|
-
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
4
|
-
from msprobe.pytorch.service import Service
|
|
5
|
-
from msprobe.pytorch.common.log import logger
|
|
6
|
-
from msprobe.pytorch.pt_config import parse_json_config
|
|
19
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
7
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
8
|
-
from msprobe.core.common.
|
|
21
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
22
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
9
25
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
26
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
27
|
+
from msprobe.pytorch.service import Service
|
|
28
|
+
from torch.utils.data import dataloader
|
|
29
|
+
|
|
30
|
+
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
|
|
31
|
+
"dump_path", "level", "model"])
|
|
10
32
|
|
|
11
33
|
|
|
12
34
|
class PrecisionDebugger:
|
|
@@ -30,20 +52,26 @@ class PrecisionDebugger:
|
|
|
30
52
|
step=None,
|
|
31
53
|
):
|
|
32
54
|
if not hasattr(self, "initialized"):
|
|
55
|
+
config_params = ConfigParameters(config_path,
|
|
56
|
+
task,
|
|
57
|
+
dump_path,
|
|
58
|
+
level,
|
|
59
|
+
model)
|
|
60
|
+
self.check_input_params(config_params)
|
|
61
|
+
|
|
33
62
|
self.api_origin = False
|
|
34
63
|
self.initialized = True
|
|
35
|
-
self.model =
|
|
64
|
+
self.model = model
|
|
36
65
|
common_config, task_config = parse_json_config(config_path, task)
|
|
37
|
-
self.task = common_config.task
|
|
66
|
+
self.task = task if task else common_config.task
|
|
38
67
|
if self.task == Const.GRAD_PROBE:
|
|
39
68
|
self.gm = GradientMonitor(common_config, task_config)
|
|
40
69
|
return
|
|
41
70
|
if step:
|
|
42
|
-
common_config.step = step
|
|
71
|
+
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
43
72
|
self.config = DebuggerConfig(
|
|
44
73
|
common_config, task_config, task, dump_path, level
|
|
45
74
|
)
|
|
46
|
-
self.config.check_model(self.model)
|
|
47
75
|
self.service = Service(self.config)
|
|
48
76
|
self.enable_dataloader = self.config.enable_dataloader
|
|
49
77
|
if self.enable_dataloader:
|
|
@@ -55,20 +83,40 @@ class PrecisionDebugger:
|
|
|
55
83
|
return self._instance
|
|
56
84
|
|
|
57
85
|
@staticmethod
|
|
58
|
-
def
|
|
59
|
-
if
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
86
|
+
def check_input_params(args):
|
|
87
|
+
if args.config_path is not None:
|
|
88
|
+
if not isinstance(args.config_path, str):
|
|
89
|
+
raise MsprobeException(
|
|
90
|
+
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
91
|
+
file_checker = FileChecker(
|
|
92
|
+
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
93
|
+
file_checker.common_check()
|
|
94
|
+
|
|
95
|
+
if args.task is not None and args.task not in Const.TASK_LIST:
|
|
96
|
+
raise MsprobeException(
|
|
97
|
+
MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
|
|
98
|
+
|
|
99
|
+
if args.dump_path is not None:
|
|
100
|
+
if not isinstance(args.dump_path, str):
|
|
101
|
+
raise MsprobeException(
|
|
102
|
+
MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
|
|
103
|
+
|
|
104
|
+
if args.level is not None and args.level not in Const.LEVEL_LIST:
|
|
105
|
+
raise MsprobeException(
|
|
106
|
+
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
107
|
+
|
|
108
|
+
if args.model is not None and not isinstance(args.model, torch.nn.Module):
|
|
109
|
+
raise MsprobeException(
|
|
110
|
+
MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
|
|
64
111
|
|
|
65
112
|
@classmethod
|
|
66
|
-
def start(cls):
|
|
113
|
+
def start(cls, model=None):
|
|
67
114
|
instance = cls._instance
|
|
115
|
+
if not instance:
|
|
116
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
68
117
|
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
69
118
|
return
|
|
70
|
-
|
|
71
|
-
raise Exception("No instance of PrecisionDebugger found.")
|
|
119
|
+
instance.config.check_model(instance, model)
|
|
72
120
|
if instance.enable_dataloader:
|
|
73
121
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
74
122
|
else:
|
|
@@ -85,10 +133,10 @@ class PrecisionDebugger:
|
|
|
85
133
|
@classmethod
|
|
86
134
|
def stop(cls):
|
|
87
135
|
instance = cls._instance
|
|
136
|
+
if not instance:
|
|
137
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
88
138
|
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
89
139
|
return
|
|
90
|
-
if not instance:
|
|
91
|
-
raise Exception("PrecisionDebugger instance is not created.")
|
|
92
140
|
if instance.enable_dataloader:
|
|
93
141
|
logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
|
|
94
142
|
else:
|
|
@@ -96,16 +144,16 @@ class PrecisionDebugger:
|
|
|
96
144
|
|
|
97
145
|
@classmethod
|
|
98
146
|
def step(cls):
|
|
147
|
+
if not cls._instance:
|
|
148
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
99
149
|
if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
100
150
|
return
|
|
101
|
-
if not cls._instance:
|
|
102
|
-
raise Exception("PrecisionDebugger instance is not created.")
|
|
103
151
|
cls._instance.service.step()
|
|
104
152
|
|
|
105
153
|
@classmethod
|
|
106
154
|
def monitor(cls, model):
|
|
107
155
|
if not cls._instance:
|
|
108
|
-
raise Exception(
|
|
156
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
109
157
|
if cls._instance.task != Const.GRAD_PROBE:
|
|
110
158
|
return
|
|
111
159
|
cls._instance.gm.monitor(model)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.file_utils import save_json
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_kernel_config_json(dump_path, cur_rank):
|
|
22
|
+
kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
|
|
23
|
+
kernel_config_path = os.path.join(dump_path, kernel_config_name)
|
|
24
|
+
config_info = {
|
|
25
|
+
"dump": {
|
|
26
|
+
"dump_list": [],
|
|
27
|
+
"dump_path": dump_path,
|
|
28
|
+
"dump_mode": "all",
|
|
29
|
+
"dump_op_switch": "on"
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
save_json(kernel_config_path, config_info, indent=4)
|
|
33
|
+
return kernel_config_path
|
|
@@ -1,8 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]
|
|
17
|
+
|
|
3
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
4
21
|
|
|
5
|
-
from .main import FreeBenchmarkCheck
|
|
6
22
|
from .common.params import UnequalRow
|
|
7
|
-
|
|
8
|
-
__all__ = [FreeBenchmarkCheck, UnequalRow]
|
|
23
|
+
from .main import FreeBenchmarkCheck
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Dict
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from collections import defaultdict
|
|
2
17
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
3
18
|
|
|
@@ -1,3 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.const import Const
|
|
17
|
+
|
|
18
|
+
|
|
1
19
|
class PerturbationMode:
|
|
2
20
|
ADD_NOISE = "add_noise"
|
|
3
21
|
CHANGE_VALUE = "change_value"
|
|
@@ -35,3 +53,28 @@ class FuzzLevel:
|
|
|
35
53
|
BASE_LEVEL = "L1"
|
|
36
54
|
ADV_LEVEL = "L2"
|
|
37
55
|
REAL_LEVEL = "L3"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class PytorchFreeBenchmarkConst:
    """Constant lists and default values for the PyTorch free-benchmark, grouped by topic."""

    # --- perturbation modes ---
    PERTURBATION_MODE_LIST = [
        PerturbationMode.ADD_NOISE,
        PerturbationMode.CHANGE_VALUE,
        PerturbationMode.IMPROVE_PRECISION,
        PerturbationMode.NO_CHANGE,
        PerturbationMode.BIT_NOISE,
        PerturbationMode.TO_CPU,
    ]
    DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION
    FIX_MODE_LIST = [
        PerturbationMode.IMPROVE_PRECISION,
        PerturbationMode.TO_CPU,
    ]
    CPU_MODE_LIST = [PerturbationMode.TO_CPU]

    # --- devices ---
    DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU]
    DEFAULT_DEVICE = DeviceType.NPU

    # --- handlers ---
    HANDLER_LIST = [HandlerType.CHECK, HandlerType.FIX]
    DEFAULT_HANDLER = HandlerType.CHECK

    # --- fuzz levels ---
    FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL]
    DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL

    # --- stages ---
    FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD]
    DEFAULT_FUZZ_STAGE = Const.FORWARD
    FIX_STAGE_LIST = [Const.FORWARD]

    # --- numeric defaults ---
    DEFAULT_PREHEAT_STEP = 15
    DEFAULT_MAX_SAMPLE = 20
|
|
@@ -1,7 +1,23 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from dataclasses import dataclass
|
|
2
17
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
3
18
|
|
|
4
19
|
import torch
|
|
20
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
5
21
|
from msprobe.pytorch.free_benchmark import logger
|
|
6
22
|
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
7
23
|
DeviceType,
|
|
@@ -113,7 +129,13 @@ def make_unequal_row(
|
|
|
113
129
|
row.max_rel = ratio - 1
|
|
114
130
|
origin_tensor = data_params.original_result
|
|
115
131
|
perturbed_tensor = data_params.perturbed_result
|
|
116
|
-
if index:
|
|
132
|
+
if index is not None:
|
|
133
|
+
if index >= len(origin_tensor) or index >= len(perturbed_tensor):
|
|
134
|
+
err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
|
|
135
|
+
raise FreeBenchmarkException(
|
|
136
|
+
FreeBenchmarkException.OutputIndexError,
|
|
137
|
+
error_info=err_msg,
|
|
138
|
+
)
|
|
117
139
|
origin_tensor = origin_tensor[index]
|
|
118
140
|
perturbed_tensor = perturbed_tensor[index]
|
|
119
141
|
row.output_index = index
|
|
@@ -1,4 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
1
17
|
import torch
|
|
18
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
19
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
2
20
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
3
21
|
|
|
4
22
|
|
|
@@ -36,6 +54,7 @@ class Tools:
|
|
|
36
54
|
return api_name.rsplit(".", 2)[0]
|
|
37
55
|
|
|
38
56
|
@staticmethod
|
|
57
|
+
@recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
|
|
39
58
|
def convert_device_and_dtype(
|
|
40
59
|
tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
|
|
41
60
|
):
|
|
@@ -58,24 +77,43 @@ class Tools:
|
|
|
58
77
|
return tensor_seq
|
|
59
78
|
|
|
60
79
|
@staticmethod
@recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
def convert_fuzz_output_to_origin(origin, perturbed):
    """Map a perturbed output back onto the structure of the original output.

    Both arguments must have matching container kinds at every level:

    * tensor / tensor: ``origin.data`` is overwritten with ``perturbed``
      converted to ``origin``'s dtype and then its device; ``origin`` is returned.
    * dict / dict: a new dict with ``origin``'s keys is built recursively.
    * tuple-or-list / tuple-or-list: a new sequence of ``type(origin)`` is
      built recursively, element by element.

    Raises:
        FreeBenchmarkException: ``InvalidPerturbedOutput`` when a key of
            ``origin`` is absent from ``perturbed`` or the sequence lengths
            differ; ``UnsupportedType`` for any other pair of types.
    """
    if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
        origin.data = perturbed.to(origin.dtype).to(origin.device)
        return origin

    if isinstance(origin, dict) and isinstance(perturbed, dict):
        converted = {}
        # Keys are checked one at a time, so entries before a missing key
        # have already been converted when the exception is raised.
        for key, origin_item in origin.items():
            if key not in perturbed:
                raise FreeBenchmarkException(
                    FreeBenchmarkException.InvalidPerturbedOutput,
                    error_info=f"'{key}' not in perturbed output.",
                )
            converted[key] = Tools.convert_fuzz_output_to_origin(
                origin_item, perturbed[key]
            )
        return converted

    sequence_types = (tuple, list)
    if isinstance(origin, sequence_types) and isinstance(perturbed, sequence_types):
        if len(perturbed) != len(origin):
            length_msg = (
                f"length of perturbed output ({len(perturbed)}) is different "
                f"from the length of original output ({len(origin)})."
            )
            raise FreeBenchmarkException(
                FreeBenchmarkException.InvalidPerturbedOutput, error_info=length_msg
            )
        # Lengths are equal here, so pairing with zip visits the same
        # elements as indexing perturbed by position.
        items = [
            Tools.convert_fuzz_output_to_origin(origin_item, perturbed_item)
            for origin_item, perturbed_item in zip(origin, perturbed)
        ]
        return type(origin)(items)

    raise FreeBenchmarkException(
        FreeBenchmarkException.UnsupportedType,
        error_info=f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported.",
    )
|
|
115
|
+
|
|
116
|
+
|
|
79
117
|
class TorchC:
|
|
80
118
|
sum = torch._C._VariableFunctionsClass.sum
|
|
81
119
|
isinf = torch._C._VariableFunctionsClass.isinf
|