mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -1,40 +1,70 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
|
-
|
|
17
|
+
import re
|
|
18
|
+
|
|
3
19
|
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.exceptions import ScopeException
|
|
4
21
|
|
|
5
22
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
scope =
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
23
|
+
class ScopeFactory:
|
|
24
|
+
def __init__(self, config):
|
|
25
|
+
self.task = config.task
|
|
26
|
+
self.level = config.level
|
|
27
|
+
self.scope = config.scope
|
|
28
|
+
self.api_list = config.list
|
|
29
|
+
|
|
30
|
+
def build_scope(self):
|
|
31
|
+
if not self.scope and not self.api_list:
|
|
32
|
+
return None
|
|
33
|
+
if self.scope is None:
|
|
34
|
+
self.scope = []
|
|
35
|
+
if self.api_list is None:
|
|
36
|
+
self.api_list = []
|
|
37
|
+
if self.task == Const.FREE_BENCHMARK:
|
|
38
|
+
return ListScope(self.scope, self.api_list)
|
|
39
|
+
return self._build_range_scope()
|
|
40
|
+
|
|
41
|
+
def _build_range_scope(self):
|
|
42
|
+
api_range_scope = APIRangeScope(self.scope, self.api_list, self.level)
|
|
43
|
+
module_range_scope = ModuleRangeScope(self.scope, self.api_list, self.level)
|
|
44
|
+
mix_range_scope = MixRangeScope(self.scope, self.api_list, self.level)
|
|
45
|
+
|
|
46
|
+
if self.level == Const.LEVEL_MIX:
|
|
47
|
+
return mix_range_scope
|
|
48
|
+
|
|
49
|
+
if not self.scope:
|
|
50
|
+
return api_range_scope
|
|
51
|
+
if api_range_scope.is_valid and module_range_scope.is_valid:
|
|
52
|
+
raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}.")
|
|
53
|
+
elif api_range_scope.is_valid:
|
|
54
|
+
return api_range_scope
|
|
55
|
+
elif module_range_scope.is_valid:
|
|
56
|
+
return module_range_scope
|
|
57
|
+
else:
|
|
58
|
+
raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}")
|
|
31
59
|
|
|
32
60
|
|
|
33
61
|
class BaseScope(ABC):
|
|
34
62
|
Module_Type_Module = "Module"
|
|
35
63
|
Module_Type_API = "api"
|
|
64
|
+
module_type = ["Module", "Cell"]
|
|
36
65
|
|
|
37
|
-
def __init__(self, scope, api_list):
|
|
66
|
+
def __init__(self, scope, api_list, level=None):
|
|
67
|
+
self.level = level
|
|
38
68
|
scope, api_list = self.rectify_args(scope, api_list)
|
|
39
69
|
self.scope = scope
|
|
40
70
|
self.api_list = api_list
|
|
@@ -81,9 +111,9 @@ class ListScope(BaseScope):
|
|
|
81
111
|
f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
|
|
82
112
|
return super(ListScope, ListScope).rectify_args(scope, api_list)
|
|
83
113
|
|
|
84
|
-
def check(self,
|
|
85
|
-
if not self.scope or
|
|
86
|
-
return self.check_api_list(
|
|
114
|
+
def check(self, name):
|
|
115
|
+
if not self.scope or name in self.scope:
|
|
116
|
+
return self.check_api_list(name)
|
|
87
117
|
return False
|
|
88
118
|
|
|
89
119
|
|
|
@@ -92,19 +122,36 @@ class RangeScope(BaseScope, ABC):
|
|
|
92
122
|
def __init__(self, *args):
|
|
93
123
|
super().__init__(*args)
|
|
94
124
|
self.in_scope = False
|
|
125
|
+
self.in_list = False
|
|
95
126
|
self.is_valid = self.check_scope_is_valid()
|
|
96
127
|
|
|
128
|
+
def check_name_pattern(self, name):
|
|
129
|
+
options_pattern = "|".join(re.escape(option) for option in Const.DUMP_PREFIX)
|
|
130
|
+
api_pattern = rf"^({options_pattern})\..*\.\d+\.(forward|backward)$"
|
|
131
|
+
module_pattern = r"^(Cell|Module)\..*\.(forward|backward)\.\d+$"
|
|
97
132
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
|
|
101
|
-
if isinstance(scope, list):
|
|
102
|
-
if len(scope) == 1:
|
|
103
|
-
scope.append(scope[0])
|
|
104
|
-
elif len(scope) > 2:
|
|
133
|
+
if self.level == Const.LEVEL_L1:
|
|
134
|
+
if not re.match(api_pattern, name):
|
|
105
135
|
raise ScopeException(ScopeException.InvalidScope,
|
|
106
|
-
|
|
136
|
+
f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
|
|
137
|
+
|
|
138
|
+
if self.level == Const.LEVEL_L0:
|
|
139
|
+
if not re.match(module_pattern, name):
|
|
140
|
+
raise ScopeException(ScopeException.InvalidScope,
|
|
141
|
+
f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
|
|
142
|
+
|
|
143
|
+
if self.level == Const.LEVEL_MIX:
|
|
144
|
+
if not re.match(api_pattern, name) and not re.match(module_pattern, name):
|
|
145
|
+
raise ScopeException(ScopeException.InvalidScope,
|
|
146
|
+
f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
|
|
107
147
|
|
|
148
|
+
def rectify_args(self, scope, api_list):
|
|
149
|
+
scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
|
|
150
|
+
if scope and len(scope) != 2:
|
|
151
|
+
raise ScopeException(ScopeException.InvalidScope,
|
|
152
|
+
f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
|
|
153
|
+
for name in scope:
|
|
154
|
+
self.check_name_pattern(name)
|
|
108
155
|
return scope, api_list
|
|
109
156
|
|
|
110
157
|
@abstractmethod
|
|
@@ -123,23 +170,23 @@ class APIRangeScope(RangeScope):
|
|
|
123
170
|
if not self.scope:
|
|
124
171
|
return True
|
|
125
172
|
scope_start_type = self.scope[0].split(Const.SEP)[0]
|
|
126
|
-
if scope_start_type
|
|
173
|
+
if scope_start_type in BaseScope.module_type:
|
|
127
174
|
return False
|
|
128
175
|
scope_stop_type = self.scope[1].split(Const.SEP)[0]
|
|
129
|
-
if scope_stop_type
|
|
176
|
+
if scope_stop_type in BaseScope.module_type:
|
|
130
177
|
return False
|
|
131
178
|
return True
|
|
132
179
|
|
|
133
|
-
def check(self,
|
|
134
|
-
if self.scope and
|
|
180
|
+
def check(self, name):
|
|
181
|
+
if self.scope and name == self.scope[0]:
|
|
135
182
|
self.in_scope = True
|
|
136
183
|
|
|
137
184
|
if not self.scope or self.in_scope:
|
|
138
|
-
result = self.check_api_list(
|
|
185
|
+
result = self.check_api_list(name)
|
|
139
186
|
else:
|
|
140
187
|
result = False
|
|
141
188
|
|
|
142
|
-
if self.scope and
|
|
189
|
+
if self.scope and name == self.scope[1]:
|
|
143
190
|
self.in_scope = False
|
|
144
191
|
return result
|
|
145
192
|
|
|
@@ -150,13 +197,14 @@ class ModuleRangeScope(RangeScope):
|
|
|
150
197
|
需要用pre_hook和full_backward_hook来精确控制module的开始和结束,
|
|
151
198
|
在这些hook触发时调用begin_module和end_module做区间控制
|
|
152
199
|
"""
|
|
200
|
+
|
|
153
201
|
def check_scope_is_valid(self):
|
|
154
202
|
if not self.scope:
|
|
155
203
|
return True
|
|
156
204
|
scope_start_type = self.scope[0].split(Const.SEP)[0]
|
|
157
205
|
scope_stop_type = self.scope[1].split(Const.SEP)[0]
|
|
158
|
-
if scope_start_type
|
|
159
|
-
scope_stop_type
|
|
206
|
+
if scope_start_type in BaseScope.module_type and \
|
|
207
|
+
scope_stop_type in BaseScope.module_type:
|
|
160
208
|
return True
|
|
161
209
|
return False
|
|
162
210
|
|
|
@@ -172,7 +220,54 @@ class ModuleRangeScope(RangeScope):
|
|
|
172
220
|
if module_name == self.scope[1]:
|
|
173
221
|
self.in_scope = False
|
|
174
222
|
|
|
175
|
-
def check(self,
|
|
223
|
+
def check(self, name):
|
|
176
224
|
if not self.scope or self.in_scope:
|
|
177
|
-
return self.check_api_list(
|
|
225
|
+
return self.check_api_list(name)
|
|
178
226
|
return False
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class MixRangeScope(RangeScope):
|
|
230
|
+
def check_scope_is_valid(self):
|
|
231
|
+
return True if self.scope else False
|
|
232
|
+
|
|
233
|
+
def begin_module(self, module_name):
|
|
234
|
+
if self.scope and module_name == self.scope[0]:
|
|
235
|
+
self.in_scope = True
|
|
236
|
+
for name in self.api_list:
|
|
237
|
+
if name in module_name:
|
|
238
|
+
self.in_list = True
|
|
239
|
+
|
|
240
|
+
def end_module(self, module_name):
|
|
241
|
+
if self.scope and module_name == self.scope[1]:
|
|
242
|
+
self.in_scope = False
|
|
243
|
+
for name in self.api_list:
|
|
244
|
+
if name in module_name:
|
|
245
|
+
self.in_list = False
|
|
246
|
+
|
|
247
|
+
def check_api_list(self, api_name):
|
|
248
|
+
if not self.api_list:
|
|
249
|
+
return True
|
|
250
|
+
|
|
251
|
+
for name in self.api_list:
|
|
252
|
+
if name in api_name:
|
|
253
|
+
return True
|
|
254
|
+
return False
|
|
255
|
+
|
|
256
|
+
def check(self, name):
|
|
257
|
+
"""
|
|
258
|
+
dump时调用的接口,根据scope和api_list判断是否需要dump
|
|
259
|
+
"""
|
|
260
|
+
result = False
|
|
261
|
+
if self.scope and name == self.scope[0]:
|
|
262
|
+
self.in_scope = True
|
|
263
|
+
|
|
264
|
+
if not self.scope or self.in_scope:
|
|
265
|
+
if self.in_list:
|
|
266
|
+
result = True
|
|
267
|
+
else:
|
|
268
|
+
result = self.check_api_list(name)
|
|
269
|
+
|
|
270
|
+
if self.scope and name == self.scope[1]:
|
|
271
|
+
self.in_scope = False
|
|
272
|
+
return result
|
|
273
|
+
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
1
15
|
|
|
2
16
|
class GradConst:
|
|
3
17
|
|
|
@@ -33,6 +47,10 @@ class GradConst:
|
|
|
33
47
|
# direction suffix
|
|
34
48
|
DIR_SUFFIX = "dir.npy"
|
|
35
49
|
|
|
50
|
+
# bounds safety
|
|
51
|
+
BOUNDS_MINIMUM = -2**63
|
|
52
|
+
BOUNDS_MAXIMUM = 2**63 - 1
|
|
53
|
+
|
|
36
54
|
# file safty
|
|
37
55
|
DATA_DIR_AUTHORITY = 0o750
|
|
38
56
|
DATA_FILE_AUTHORITY = 0o640
|
|
@@ -56,16 +74,16 @@ class GradConst:
|
|
|
56
74
|
NORM = "norm"
|
|
57
75
|
|
|
58
76
|
level_adp = {
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
77
|
+
"L0": {
|
|
78
|
+
"header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
|
|
79
|
+
"have_grad_direction": False
|
|
80
|
+
},
|
|
81
|
+
"L1": {
|
|
82
|
+
"header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
|
|
83
|
+
"have_grad_direction": True
|
|
84
|
+
},
|
|
85
|
+
"L2": {
|
|
86
|
+
"header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
|
|
87
|
+
"have_grad_direction": True
|
|
88
|
+
},
|
|
89
|
+
}
|
|
@@ -1,13 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from typing import List
|
|
3
18
|
|
|
4
19
|
from tqdm import tqdm
|
|
5
|
-
import pandas as pd
|
|
6
20
|
import matplotlib.pyplot as plt
|
|
7
21
|
|
|
8
|
-
from msprobe.core.common.file_utils import create_directory,
|
|
22
|
+
from msprobe.core.common.file_utils import create_directory, check_file_or_directory_path
|
|
9
23
|
from msprobe.core.common.log import logger
|
|
10
|
-
from msprobe.core.common.file_utils import remove_path, load_npy, write_csv
|
|
24
|
+
from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv
|
|
11
25
|
from msprobe.core.grad_probe.constant import GradConst
|
|
12
26
|
from msprobe.core.grad_probe.utils import plt_savefig
|
|
13
27
|
|
|
@@ -21,7 +35,7 @@ class GradComparator:
|
|
|
21
35
|
continue
|
|
22
36
|
if not os.path.exists(os.path.join(path2, summary_file)):
|
|
23
37
|
continue
|
|
24
|
-
summary_csv =
|
|
38
|
+
summary_csv = read_csv(os.path.join(path1, summary_file))
|
|
25
39
|
return summary_csv["param_name"]
|
|
26
40
|
raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")
|
|
27
41
|
|
|
@@ -34,6 +48,8 @@ class GradComparator:
|
|
|
34
48
|
|
|
35
49
|
@classmethod
|
|
36
50
|
def compare_distributed(cls, path1: str, path2: str, output_dir: str):
|
|
51
|
+
check_file_or_directory_path(path1, isdir=True)
|
|
52
|
+
check_file_or_directory_path(path2, isdir=True)
|
|
37
53
|
ranks = cls._get_matched_dirs(path1, path2, "rank")
|
|
38
54
|
logger.info(f"the following ranks will be compared: {ranks}")
|
|
39
55
|
if not ranks:
|
msprobe/core/grad_probe/utils.py
CHANGED
|
@@ -1,8 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import re
|
|
2
17
|
from msprobe.core.grad_probe.constant import GradConst
|
|
3
18
|
from msprobe.core.common.log import logger
|
|
4
19
|
from msprobe.core.common.file_utils import write_csv, check_path_before_create, change_mode
|
|
5
20
|
from msprobe.core.common.const import FileCheckConst
|
|
21
|
+
from msprobe.core.common.utils import is_int
|
|
6
22
|
import matplotlib.pyplot as plt
|
|
7
23
|
|
|
8
24
|
|
|
@@ -20,12 +36,37 @@ def check_numeral_list_ascend(lst):
|
|
|
20
36
|
def check_param(param_name):
|
|
21
37
|
if not re.match(GradConst.PARAM_VALID_PATTERN, param_name):
|
|
22
38
|
raise RuntimeError("The parameter name contains special characters.")
|
|
23
|
-
|
|
39
|
+
|
|
24
40
|
|
|
25
41
|
def check_str(string, variable_name):
|
|
26
42
|
if not isinstance(string, str):
|
|
27
43
|
raise ValueError(f'The variable: "{variable_name}" is not a string.')
|
|
28
|
-
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def check_bounds_element(bound):
    """Return whether ``bound`` lies within the closed range
    [``GradConst.BOUNDS_MINIMUM``, ``GradConst.BOUNDS_MAXIMUM``].
    """
    lower_ok = GradConst.BOUNDS_MINIMUM <= bound
    return lower_ok and bound <= GradConst.BOUNDS_MAXIMUM
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def check_param_element(param):
    """Return True when ``param`` matches ``GradConst.PARAM_VALID_PATTERN``,
    False otherwise.

    Unlike ``check_param``, this reports the result instead of raising.
    """
    return re.match(GradConst.PARAM_VALID_PATTERN, param) is not None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def check_bounds(bounds):
    """Validate a list of bounds.

    The list must contain only ints or floats, each accepted by
    ``check_bounds_element``, in strictly ascending order.

    Args:
        bounds: Candidate list of numeric bounds.

    Raises:
        Exception: If ``bounds`` is not a list, an element has the wrong type,
            an element is rejected by ``check_bounds_element``, or the
            elements are not strictly ascending.
    """
    if not isinstance(bounds, list):
        raise Exception("bounds must be a list")

    # Seed just below the allowed minimum so the first in-range element
    # always passes the ascending check.
    last_seen = GradConst.BOUNDS_MINIMUM - 1
    for value in bounds:
        if not (is_int(value) or isinstance(value, float)):
            raise Exception("bounds element is not int or float")
        if not check_bounds_element(value):
            raise Exception("bounds element is out of int64 range")
        if last_seen >= value:
            raise Exception("bounds list is not ascending")
        last_seen = value
|
|
69
|
+
|
|
29
70
|
|
|
30
71
|
class ListCache(list):
|
|
31
72
|
threshold = 1000
|
|
@@ -50,7 +91,7 @@ class ListCache(list):
|
|
|
50
91
|
list.append(self, data)
|
|
51
92
|
if len(self) >= ListCache.threshold:
|
|
52
93
|
self.flush()
|
|
53
|
-
|
|
94
|
+
|
|
54
95
|
def set_output_file(self, output_file):
|
|
55
96
|
self._output_file = output_file
|
|
56
97
|
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import List, Dict, Union, Any
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from msprobe.core.overflow_check.api_info import APIInfo
|
|
21
|
+
from msprobe.core.overflow_check.level import OverflowLevel
|
|
22
|
+
from msprobe.core.overflow_check.utils import has_nan_inf
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnomalyScene:
    """Base class for anomaly scenes evaluated on the data of one API call."""

    def __init__(self, api_info: APIInfo):
        self.api_name = api_info.api_name
        self.api_data = api_info

    @property
    def rank(self) -> OverflowLevel:
        """Severity level of the scene; subclasses must override."""
        raise NotImplementedError

    @staticmethod
    def _has_anomaly(data: Union[Dict, Any]) -> bool:
        """Return whether ``data`` contains anomalous values, as decided by ``has_nan_inf``."""
        return has_nan_inf(data)

    def get_details(self) -> Dict:
        """Return a dictionary describing the anomaly.

        It holds the API name, the value of ``rank``, the concrete scene class
        name, and the positions/keys of the input args, input kwargs and
        outputs flagged by ``_has_anomaly``.
        """
        return {
            'api_name': self.api_name,
            'rank': self.rank.value,
            'scene_type': self.__class__.__name__,
            'input_args_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.input_args),
            'input_kwargs_anomaly_keys': self._get_anomaly_keys_from_dict(self.api_data.input_kwargs),
            'output_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.output_data)
        }

    def matches(self) -> bool:
        """Return whether the API data matches this scene; subclasses must override."""
        raise NotImplementedError

    def _get_anomaly_indices_from_list(self, data_list: List[Dict]) -> List[int]:
        """Return the indices of the entries of ``data_list`` flagged by ``_has_anomaly``.

        A missing container (``None``, the default of the ``APIInfo`` fields)
        is treated as empty instead of raising.
        """
        if not data_list:
            return []
        return [i for i, data in enumerate(data_list) if self._has_anomaly(data)]

    def _get_anomaly_keys_from_dict(self, data_dict: Dict) -> List[str]:
        """Return the keys of ``data_dict`` whose values are flagged by ``_has_anomaly``.

        A missing container (``None``, the default of the ``APIInfo`` fields)
        is treated as empty instead of raising.
        """
        if not data_dict:
            return []
        return [key for key, data in data_dict.items() if self._has_anomaly(data)]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class InputOutputAnomalyScene(AnomalyScene):
    """Base class for scenes classified by whether inputs and/or outputs are anomalous."""

    def has_input_anomaly(self) -> bool:
        """Return whether any input (positional args or kwargs) is anomalous.

        Only dict entries are inspected. A missing container (``None``, the
        default of the ``APIInfo`` fields) is treated as empty.
        """
        # Positional arguments.
        positional = self.api_data.input_args or []
        if any(self._has_anomaly(x) for x in positional if isinstance(x, dict)):
            return True
        # Keyword arguments.
        keyword = self.api_data.input_kwargs or {}
        return any(self._has_anomaly(x) for x in keyword.values() if isinstance(x, dict))

    def has_output_anomaly(self) -> bool:
        """Return whether any output is anomalous.

        Only dict entries are inspected. A missing container (``None``) is
        treated as empty.
        """
        outputs = self.api_data.output_data or []
        return any(self._has_anomaly(x) for x in outputs if isinstance(x, dict))

    def matches(self) -> bool:
        """Return whether the API data matches this scene; subclasses must override."""
        raise NotImplementedError
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class InputAnomalyOutputNormalScene(InputOutputAnomalyScene):
    """Scene in which the inputs are anomalous while the outputs are normal."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.MEDIUM``."""
        return OverflowLevel.MEDIUM

    def matches(self) -> bool:
        """Return True when an input is anomalous and no output is."""
        if not self.has_input_anomaly():
            return False
        return not self.has_output_anomaly()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class InputAnomalyOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene in which both the inputs and the outputs are anomalous."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    def matches(self) -> bool:
        """Return True when an input is anomalous and an output is too."""
        if not self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class InputNormalOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene in which the inputs are normal while the outputs are anomalous."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.CRITICAL``."""
        return OverflowLevel.CRITICAL

    def matches(self) -> bool:
        """Return True when no input is anomalous but an output is."""
        if self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class NumericalMutationScene(AnomalyScene):
    """Detect an abrupt change of numerical scale between inputs and outputs.

    The ``Norm`` statistic of the tensors in the input args and kwargs and of
    the output tensors is collected; the largest output norm is compared with
    the largest input norm, and a ratio above ``threshold`` is treated as an
    anomaly.
    """

    def __init__(self, api_info: APIInfo, threshold: float = 100000.0):
        super().__init__(api_info)
        self.threshold = threshold

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    @staticmethod
    def _get_tensor_norms(data_list: List[Dict]) -> List[float]:
        """Collect the non-NaN ``Norm`` values of the ``torch.Tensor`` entries.

        Entries that are not dicts, are not typed ``'torch.Tensor'``, or have
        no ``Norm`` are skipped. A missing container (``None``, the default of
        the ``APIInfo`` fields) yields an empty list.
        """
        norms = []
        for data in data_list or []:
            if not isinstance(data, dict) or data.get('type') != 'torch.Tensor':
                continue
            norm = data.get('Norm')
            if norm is not None and not np.isnan(norm):
                norms.append(norm)
        return norms

    @staticmethod
    def _get_kwargs_norms(data_dict: Dict) -> List[float]:
        """Collect the norms of the tensors stored as values of ``data_dict``.

        Applies the same filtering as ``_get_tensor_norms``; a missing or
        empty dict yields an empty list.
        """
        if not data_dict:
            return []
        return NumericalMutationScene._get_tensor_norms(list(data_dict.values()))

    def matches(self) -> bool:
        """Return whether the output scale jumped relative to the input scale.

        Returns False when either side has no usable norm.
        """
        # Norms of every input (positional and keyword).
        input_norms = (self._get_tensor_norms(self.api_data.input_args) +
                       self._get_kwargs_norms(self.api_data.input_kwargs))
        # Norms of every output.
        output_norms = self._get_tensor_norms(self.api_data.output_data)

        if not input_norms or not output_norms:
            return False

        max_input = max(input_norms)
        max_output = max(output_norms)

        # Avoid dividing by zero: with an all-zero input, compare the output
        # norm with the threshold directly.
        if max_input == 0:
            return max_output > self.threshold
        return max_output / max_input > self.threshold

    def get_details(self) -> Dict:
        """Return the base details extended with the threshold and the match result."""
        details = super().get_details()
        details.update({
            'threshold': self.threshold,
            'scale_change_detected': self.matches()
        })
        return details
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
|
|
18
|
+
from typing import Dict, List
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class APIInfo:
    """Data recorded for a single API call.

    The explicit ``__init__`` below is the one in effect: ``@dataclass`` does
    not replace an ``__init__`` that the class defines itself.
    """
    # Full API name as passed to the constructor.
    api_name: str
    # Lower-cased first two separator-delimited fields of ``api_name``.
    torch_api_name: str
    # Positional inputs, keyword inputs and outputs; stored as given and left
    # as None when the caller does not supply them.
    input_args: List[Dict]
    input_kwargs: Dict
    output_data: List[Dict]

    def __init__(self, api_name, input_args=None, input_kwargs=None, output_data=None):
        self.api_name = api_name
        self.input_args = input_args
        self.input_kwargs = input_kwargs
        self.output_data = output_data
        self.torch_api_name = self.extract_torch_api(self.api_name)

    @staticmethod
    def extract_torch_api(api_name) -> str:
        """Return the first two ``Const.SEP``-separated fields of ``api_name`` in lowercase.

        A blank (empty or whitespace-only) name yields ``""``; a name with a
        single field yields that field in lowercase.
        """
        if not api_name.strip():
            return ""

        # str.split always yields at least one element, so slicing the first
        # two parts and joining them also covers the single-field case.
        parts = api_name.split(Const.SEP)
        return Const.SEP.join(parts[:2]).lower()
|