mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,6 +21,7 @@ from typing import List
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import torch
|
|
23
23
|
from torch import distributed as dist
|
|
24
|
+
from torch.distributed.distributed_c10d import _get_default_group
|
|
24
25
|
|
|
25
26
|
from msprobe.core.common.const import Const
|
|
26
27
|
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
@@ -40,7 +41,16 @@ except ImportError:
|
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
43
|
-
pytorch_special_type = (
|
|
44
|
+
pytorch_special_type = (
|
|
45
|
+
torch.device,
|
|
46
|
+
torch.dtype,
|
|
47
|
+
torch.Size,
|
|
48
|
+
torch.Tensor,
|
|
49
|
+
torch.memory_format,
|
|
50
|
+
dist.ProcessGroup,
|
|
51
|
+
dist.P2POp,
|
|
52
|
+
dist.ReduceOp
|
|
53
|
+
)
|
|
44
54
|
memory_format = {
|
|
45
55
|
torch.contiguous_format: "contiguous_format",
|
|
46
56
|
torch.channels_last: "channels_last",
|
|
@@ -54,6 +64,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
54
64
|
"device": self.analyze_device_in_kwargs,
|
|
55
65
|
"dtype": self.analyze_dtype_in_kwargs
|
|
56
66
|
}
|
|
67
|
+
self._async_dump_cache = {}
|
|
57
68
|
|
|
58
69
|
@staticmethod
|
|
59
70
|
def get_md5_for_tensor(x):
|
|
@@ -82,49 +93,80 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
82
93
|
return {"type": "torch.dtype", "value": str(element)}
|
|
83
94
|
|
|
84
95
|
@staticmethod
|
|
85
|
-
def
|
|
96
|
+
def get_stat_info_async(data):
|
|
86
97
|
tensor_stat = TensorStatInfo()
|
|
87
|
-
if data
|
|
88
|
-
|
|
89
|
-
data_clone = data.detach()
|
|
90
|
-
if data_clone.numel() == 0:
|
|
98
|
+
if torch.is_complex(data):
|
|
99
|
+
logger.warning("Async dump do not support complex data!")
|
|
91
100
|
return tensor_stat
|
|
92
|
-
elif
|
|
93
|
-
tensor_stat.
|
|
94
|
-
|
|
95
|
-
elif not
|
|
96
|
-
tensor_stat.
|
|
97
|
-
|
|
98
|
-
|
|
101
|
+
elif data.dtype == torch.bool:
|
|
102
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
|
|
103
|
+
[torch.any(data), torch.all(data)]))
|
|
104
|
+
elif not data.shape:
|
|
105
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
|
|
106
|
+
else:
|
|
107
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
108
|
+
data = data.float()
|
|
109
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
|
|
110
|
+
torch.max(data),
|
|
111
|
+
torch.min(data),
|
|
112
|
+
torch.mean(data),
|
|
113
|
+
torch.norm(data)
|
|
114
|
+
]))
|
|
115
|
+
return tensor_stat
|
|
116
|
+
|
|
117
|
+
@staticmethod
|
|
118
|
+
def get_stat_info_sync(data):
|
|
119
|
+
tensor_stat = TensorStatInfo()
|
|
120
|
+
if torch.is_complex(data):
|
|
121
|
+
data_np = data.cpu().numpy()
|
|
99
122
|
data_abs = np.abs(data_np)
|
|
100
123
|
tensor_stat.max = np.max(data_abs).item()
|
|
101
124
|
tensor_stat.min = np.min(data_abs).item()
|
|
102
125
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
126
|
+
elif data.dtype == torch.bool:
|
|
127
|
+
tensor_stat.max = torch.any(data).item()
|
|
128
|
+
tensor_stat.min = torch.all(data).item()
|
|
129
|
+
elif not data.shape:
|
|
130
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
|
|
103
131
|
else:
|
|
104
|
-
if not
|
|
105
|
-
|
|
106
|
-
tensor_stat.max = torch.
|
|
107
|
-
tensor_stat.min = torch.
|
|
108
|
-
tensor_stat.mean = torch.
|
|
109
|
-
tensor_stat.norm = torch.
|
|
132
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
133
|
+
data = data.float()
|
|
134
|
+
tensor_stat.max = torch.max(data).item()
|
|
135
|
+
tensor_stat.min = torch.min(data).item()
|
|
136
|
+
tensor_stat.mean = torch.mean(data).item()
|
|
137
|
+
tensor_stat.norm = torch.norm(data).item()
|
|
110
138
|
return tensor_stat
|
|
111
139
|
|
|
140
|
+
@staticmethod
|
|
141
|
+
def get_stat_info(data, async_dump=False):
|
|
142
|
+
tensor_stat = TensorStatInfo()
|
|
143
|
+
if data.is_meta:
|
|
144
|
+
return tensor_stat
|
|
145
|
+
data_clone = data.detach()
|
|
146
|
+
if data_clone.numel() == 0:
|
|
147
|
+
return tensor_stat
|
|
148
|
+
else:
|
|
149
|
+
if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
|
|
150
|
+
return PytorchDataProcessor.get_stat_info_sync(data_clone)
|
|
151
|
+
else:
|
|
152
|
+
return PytorchDataProcessor.get_stat_info_async(data_clone)
|
|
153
|
+
|
|
112
154
|
@staticmethod
|
|
113
155
|
def handle_tensor_extremum_nan_inf(tensor, operator):
|
|
114
156
|
data_clone = tensor.detach()
|
|
115
|
-
data_nan = torch.
|
|
116
|
-
if int(torch.
|
|
157
|
+
data_nan = torch.isnan(data_clone)
|
|
158
|
+
if int(torch.sum(data_nan)) == data_clone.numel():
|
|
117
159
|
return float('nan')
|
|
118
160
|
|
|
119
|
-
finite_mask = torch.
|
|
120
|
-
if int(torch.
|
|
121
|
-
finite_values =
|
|
122
|
-
return torch.
|
|
123
|
-
torch.
|
|
161
|
+
finite_mask = torch.isfinite(data_clone)
|
|
162
|
+
if int(torch.sum(finite_mask)) > 0:
|
|
163
|
+
finite_values = data_clone[finite_mask]
|
|
164
|
+
return torch.max(finite_values).item() if operator == 'max' else \
|
|
165
|
+
torch.min(finite_values).item()
|
|
124
166
|
else:
|
|
125
|
-
data_no_nan =
|
|
126
|
-
return torch.
|
|
127
|
-
torch.
|
|
167
|
+
data_no_nan = data_clone[~data_nan]
|
|
168
|
+
return torch.max(data_no_nan).item() if operator == 'max' else \
|
|
169
|
+
torch.min(data_no_nan).item()
|
|
128
170
|
|
|
129
171
|
@staticmethod
|
|
130
172
|
def process_group_hash(arg):
|
|
@@ -132,6 +174,15 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
132
174
|
group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
|
|
133
175
|
return group_ranks_hash
|
|
134
176
|
|
|
177
|
+
@staticmethod
|
|
178
|
+
def is_distributed_op(module):
|
|
179
|
+
return getattr(module, "op_is_distributed", False)
|
|
180
|
+
|
|
181
|
+
@staticmethod
|
|
182
|
+
def is_hookable_element(element):
|
|
183
|
+
return (hasattr(element, "register_hook") and callable(element.register_hook)) and \
|
|
184
|
+
(hasattr(element, "requires_grad") and element.requires_grad)
|
|
185
|
+
|
|
135
186
|
@staticmethod
|
|
136
187
|
def _analyze_torch_size(arg):
|
|
137
188
|
return {"type": "torch.Size", "value": list(arg)}
|
|
@@ -140,7 +191,6 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
140
191
|
def _analyze_memory_format(arg):
|
|
141
192
|
# 获取内存格式
|
|
142
193
|
format_type = PytorchDataProcessor.memory_format.get(arg)
|
|
143
|
-
|
|
144
194
|
return {"type": "torch.memory_format", "format": format_type}
|
|
145
195
|
|
|
146
196
|
@staticmethod
|
|
@@ -152,9 +202,18 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
152
202
|
group_id = PytorchDataProcessor.process_group_hash(arg)
|
|
153
203
|
group_info.update({"group_id": group_id})
|
|
154
204
|
except Exception as e:
|
|
155
|
-
logger.warning(f"Failed to get process group
|
|
205
|
+
logger.warning(f"Failed to get process group ranks info with error info: {e}.")
|
|
156
206
|
return group_info
|
|
157
207
|
|
|
208
|
+
@staticmethod
|
|
209
|
+
def _analyze_reduce_op(arg):
|
|
210
|
+
op_type = None
|
|
211
|
+
try:
|
|
212
|
+
op_type = str(arg)
|
|
213
|
+
except Exception as e:
|
|
214
|
+
logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
|
|
215
|
+
return {"type": "torch.distributed.ReduceOp", "value": op_type}
|
|
216
|
+
|
|
158
217
|
@classmethod
|
|
159
218
|
def get_special_types(cls):
|
|
160
219
|
return super().get_special_types() + cls.pytorch_special_type
|
|
@@ -168,35 +227,65 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
168
227
|
return self._analyze_memory_format(element)
|
|
169
228
|
if isinstance(element, dist.ProcessGroup):
|
|
170
229
|
return self._analyze_process_group(element)
|
|
230
|
+
if isinstance(element, dist.P2POp):
|
|
231
|
+
return self._analyze_p2pop(element)
|
|
232
|
+
if isinstance(element, dist.ReduceOp):
|
|
233
|
+
return self._analyze_reduce_op(element)
|
|
171
234
|
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
172
235
|
if converted_numpy is not element:
|
|
173
|
-
return
|
|
236
|
+
return {"type": numpy_type, "value": converted_numpy}
|
|
174
237
|
if isinstance(element, torch.Tensor):
|
|
175
|
-
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
|
|
238
|
+
return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
239
|
+
if isinstance(element, np.ndarray):
|
|
240
|
+
return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
176
241
|
if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
|
|
177
242
|
return self._analyze_builtin(element)
|
|
178
243
|
return {}
|
|
179
244
|
|
|
245
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
246
|
+
if self.is_distributed_op(module):
|
|
247
|
+
module_input_output.update_output_with_args_and_kwargs()
|
|
248
|
+
return super().analyze_forward_output(name, module, module_input_output)
|
|
249
|
+
|
|
250
|
+
def _analyze_p2pop(self, arg):
|
|
251
|
+
p2pop_info = {"class_type": "torch.distributed.P2POp"}
|
|
252
|
+
try:
|
|
253
|
+
tensor_info = self._analyze_tensor(arg.tensor, [])
|
|
254
|
+
p2pop_info.update({"tensor": tensor_info})
|
|
255
|
+
p2pop_info.update({"op": arg.op.__name__})
|
|
256
|
+
p2pop_info.update({"peer": arg.peer})
|
|
257
|
+
p2pop_info.update({"tag": arg.tag})
|
|
258
|
+
group_id = PytorchDataProcessor.process_group_hash(
|
|
259
|
+
arg.group) if arg.group else PytorchDataProcessor.process_group_hash(_get_default_group())
|
|
260
|
+
p2pop_info.update({"group_id": group_id})
|
|
261
|
+
except Exception as e:
|
|
262
|
+
logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
|
|
263
|
+
return p2pop_info
|
|
264
|
+
|
|
180
265
|
def _analyze_tensor(self, tensor, suffix):
|
|
181
|
-
tensor_stat = self.get_stat_info(tensor)
|
|
266
|
+
tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
|
|
182
267
|
tensor_json = {}
|
|
183
268
|
tensor_json.update({'type': 'torch.Tensor'})
|
|
184
269
|
tensor_json.update({'dtype': str(tensor.dtype)})
|
|
185
270
|
tensor_json.update({"shape": tensor.shape})
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
271
|
+
if tensor_stat.stack_tensor_stat is None:
|
|
272
|
+
tensor_json.update({"Max": tensor_stat.max})
|
|
273
|
+
tensor_json.update({"Min": tensor_stat.min})
|
|
274
|
+
tensor_json.update({"Mean": tensor_stat.mean})
|
|
275
|
+
tensor_json.update({"Norm": tensor_stat.norm})
|
|
276
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
277
|
+
if tensor_stat.max is not None:
|
|
278
|
+
if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
|
|
279
|
+
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
|
|
280
|
+
if tensor_stat.min is not None:
|
|
281
|
+
if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
|
|
282
|
+
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
|
|
283
|
+
|
|
284
|
+
else:
|
|
285
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
286
|
+
tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
|
|
287
|
+
|
|
288
|
+
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
200
289
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
201
290
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
202
291
|
return tensor_json
|
|
@@ -207,13 +296,28 @@ class StatisticsDataProcessor(PytorchDataProcessor):
|
|
|
207
296
|
|
|
208
297
|
|
|
209
298
|
class TensorDataProcessor(PytorchDataProcessor):
|
|
299
|
+
def dump_async_data(self):
|
|
300
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
301
|
+
save_pt(tensor.contiguous(), file_path)
|
|
302
|
+
self._async_dump_cache.clear()
|
|
303
|
+
|
|
210
304
|
def _analyze_tensor(self, tensor, suffix):
|
|
211
305
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
212
|
-
saved_tensor = tensor.clone().contiguous().detach()
|
|
213
|
-
save_pt(saved_tensor, file_path)
|
|
214
306
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
215
307
|
single_arg.update({"data_name": dump_data_name})
|
|
308
|
+
if self.config.async_dump:
|
|
309
|
+
self._async_dump_cache[file_path] = tensor.clone().detach()
|
|
310
|
+
else:
|
|
311
|
+
saved_tensor = tensor.clone().contiguous().detach()
|
|
312
|
+
save_pt(saved_tensor, file_path)
|
|
216
313
|
return single_arg
|
|
314
|
+
|
|
315
|
+
def _analyze_numpy(self, ndarray, suffix):
|
|
316
|
+
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
317
|
+
save_pt(torch.tensor(ndarray), file_path)
|
|
318
|
+
ndarray_json = super()._analyze_numpy(ndarray, suffix)
|
|
319
|
+
ndarray_json.update({"data_name": dump_data_name})
|
|
320
|
+
return ndarray_json
|
|
217
321
|
|
|
218
322
|
|
|
219
323
|
class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
@@ -223,7 +327,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
223
327
|
super().__init__(config, data_writer)
|
|
224
328
|
self.has_overflow = False
|
|
225
329
|
self.support_inf_nan = None
|
|
226
|
-
self.
|
|
330
|
+
self.cached_api_info = {}
|
|
227
331
|
self.cached_tensors_and_file_paths = {}
|
|
228
332
|
self.bits_for_overflow = 8
|
|
229
333
|
self.real_overflow_nums = 0
|
|
@@ -237,21 +341,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
237
341
|
return True
|
|
238
342
|
return False
|
|
239
343
|
|
|
240
|
-
def
|
|
344
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
241
345
|
self.has_overflow = False
|
|
242
346
|
self._is_support_inf_nan()
|
|
243
|
-
self.
|
|
347
|
+
self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
|
|
244
348
|
return None
|
|
245
349
|
|
|
246
|
-
def
|
|
350
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
247
351
|
self._is_support_inf_nan()
|
|
248
|
-
api_info_struct = super().
|
|
249
|
-
if name in self.
|
|
250
|
-
self.
|
|
352
|
+
api_info_struct = super().analyze_forward_output(name, module, module_input_output)
|
|
353
|
+
if name in self.cached_api_info and name in api_info_struct:
|
|
354
|
+
self.cached_api_info[name].update(api_info_struct[name])
|
|
251
355
|
elif name in api_info_struct:
|
|
252
|
-
self.
|
|
356
|
+
self.cached_api_info = api_info_struct
|
|
253
357
|
self.handle_overflow()
|
|
254
|
-
return self.
|
|
358
|
+
return self.cached_api_info if self.has_overflow else None
|
|
255
359
|
|
|
256
360
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
257
361
|
self.has_overflow = False
|
|
@@ -267,6 +371,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
267
371
|
self.handle_overflow()
|
|
268
372
|
return api_info_struct if self.has_overflow else None
|
|
269
373
|
|
|
374
|
+
def analyze_params(self, name, param_name, grad):
|
|
375
|
+
self.has_overflow = False
|
|
376
|
+
self._is_support_inf_nan()
|
|
377
|
+
api_info_struct = super().analyze_params(name, param_name, grad)
|
|
378
|
+
self.handle_overflow()
|
|
379
|
+
return api_info_struct if self.has_overflow else None
|
|
380
|
+
|
|
270
381
|
def handle_overflow(self):
|
|
271
382
|
if not self.support_inf_nan:
|
|
272
383
|
self._analyze_maybe_overflow_flag()
|
|
@@ -340,10 +451,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
340
451
|
)
|
|
341
452
|
return
|
|
342
453
|
|
|
343
|
-
def
|
|
454
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
344
455
|
self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
|
|
345
456
|
|
|
346
|
-
def
|
|
457
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
347
458
|
new_output, unequal_rows = self.checker.forward(
|
|
348
459
|
name,
|
|
349
460
|
module,
|
|
@@ -388,7 +499,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
388
499
|
def _print_unsupported_log(api_name):
|
|
389
500
|
logger.warning(f"The kernel dump does not support the {api_name} API.")
|
|
390
501
|
|
|
391
|
-
def
|
|
502
|
+
def analyze_forward_input(self, name, module, module_input_output):
|
|
392
503
|
if not self.enable_kernel_dump:
|
|
393
504
|
return
|
|
394
505
|
if is_gpu:
|
|
@@ -413,7 +524,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
413
524
|
return
|
|
414
525
|
self.start_kernel_dump(self.config.kernel_config_path)
|
|
415
526
|
|
|
416
|
-
def
|
|
527
|
+
def analyze_forward_output(self, name, module, module_input_output):
|
|
417
528
|
if not self.enable_kernel_dump:
|
|
418
529
|
return
|
|
419
530
|
if self.config.is_backward_kernel_dump:
|
|
@@ -15,10 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
import csv
|
|
17
17
|
import os
|
|
18
|
+
import copy
|
|
19
|
+
import numpy as np
|
|
18
20
|
|
|
19
21
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
20
|
-
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
|
|
22
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
|
|
21
23
|
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class DataWriter:
|
|
@@ -29,10 +32,12 @@ class DataWriter:
|
|
|
29
32
|
self.construct_file_path = None
|
|
30
33
|
self.free_benchmark_file_path = None
|
|
31
34
|
self.dump_tensor_data_dir = None
|
|
35
|
+
self.debug_file_path = None
|
|
32
36
|
self.flush_size = 1000
|
|
33
37
|
self.cache_data = {}
|
|
34
38
|
self.cache_stack = {}
|
|
35
39
|
self.cache_construct = {}
|
|
40
|
+
self.cache_debug = {}
|
|
36
41
|
|
|
37
42
|
@staticmethod
|
|
38
43
|
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
@@ -55,6 +60,13 @@ class DataWriter:
|
|
|
55
60
|
self.cache_construct = {}
|
|
56
61
|
|
|
57
62
|
def initialize_json_file(self, **kwargs):
|
|
63
|
+
if self.debug_file_path and not self.cache_debug:
|
|
64
|
+
# debug level case only create debug.json
|
|
65
|
+
debug_dict = copy.deepcopy(kwargs)
|
|
66
|
+
debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
67
|
+
self.cache_debug = debug_dict
|
|
68
|
+
save_json(self.debug_file_path, self.cache_debug, indent=1)
|
|
69
|
+
return
|
|
58
70
|
if not self.cache_data:
|
|
59
71
|
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
60
72
|
self.cache_data = kwargs
|
|
@@ -64,13 +76,13 @@ class DataWriter:
|
|
|
64
76
|
if not self.cache_construct:
|
|
65
77
|
save_json(self.construct_file_path, self.cache_construct, indent=1)
|
|
66
78
|
|
|
67
|
-
def update_dump_paths(self,
|
|
68
|
-
|
|
69
|
-
self.
|
|
70
|
-
self.
|
|
71
|
-
self.
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
79
|
+
def update_dump_paths(self, dump_path_aggregation):
    """Copy every output path from ``dump_path_aggregation`` onto this writer.

    Args:
        dump_path_aggregation: object exposing the attributes
            ``dump_file_path``, ``stack_file_path``, ``construct_file_path``,
            ``dump_tensor_data_dir``, ``free_benchmark_file_path`` and
            ``debug_file_path``.
    """
    # Attribute names are identical on both objects, so mirror them in order.
    mirrored_attrs = (
        "dump_file_path",
        "stack_file_path",
        "construct_file_path",
        "dump_tensor_data_dir",
        "free_benchmark_file_path",
        "debug_file_path",
    )
    for attr_name in mirrored_attrs:
        setattr(self, attr_name, getattr(dump_path_aggregation, attr_name))
|
|
74
86
|
|
|
75
87
|
def flush_data_periodically(self):
|
|
76
88
|
dump_data = self.cache_data.get(Const.DATA)
|
|
@@ -98,6 +110,9 @@ class DataWriter:
|
|
|
98
110
|
def update_construct(self, new_data):
    """Merge ``new_data`` into the cached construct mapping.

    Keys already present in ``self.cache_construct`` are overwritten.
    """
    construct_cache = self.cache_construct
    construct_cache.update(new_data)
|
|
100
112
|
|
|
113
|
+
def update_debug(self, new_data):
    """Merge ``new_data`` into the ``'data'`` section of the debug cache.

    Raises:
        KeyError: if ``self.cache_debug`` has no ``'data'`` entry yet.
    """
    debug_entries = self.cache_debug['data']
    debug_entries.update(new_data)
|
|
115
|
+
|
|
101
116
|
def write_data_json(self, file_path):
|
|
102
117
|
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
|
|
103
118
|
save_json(file_path, self.cache_data, indent=1)
|
|
@@ -108,6 +123,9 @@ class DataWriter:
|
|
|
108
123
|
def write_construct_info_json(self, file_path):
    """Write the cached construct mapping to ``file_path`` via ``save_json`` (``indent=1``)."""
    construct_payload = self.cache_construct
    save_json(file_path, construct_payload, indent=1)
|
|
110
125
|
|
|
126
|
+
def write_debug_info_json(self, file_path):
    """Write the cached debug mapping to ``file_path`` via ``save_json`` (``indent=1``)."""
    debug_payload = self.cache_debug
    save_json(file_path, debug_payload, indent=1)
|
|
128
|
+
|
|
111
129
|
def write_json(self):
|
|
112
130
|
if self.cache_data:
|
|
113
131
|
self.write_data_json(self.dump_file_path)
|
|
@@ -115,3 +133,31 @@ class DataWriter:
|
|
|
115
133
|
self.write_stack_info_json(self.stack_file_path)
|
|
116
134
|
if self.cache_construct:
|
|
117
135
|
self.write_construct_info_json(self.construct_file_path)
|
|
136
|
+
if self.cache_debug:
|
|
137
|
+
self.write_debug_info_json(self.debug_file_path)
|
|
138
|
+
|
|
139
|
+
def fill_stack_tensor_data(self):
    """Run ``process_stat_data_recursive`` over the whole cached dump data, in place."""
    cached_dump = self.cache_data
    self.process_stat_data_recursive(cached_dump)
|
|
141
|
+
|
|
142
|
+
def process_stat_data_recursive(self, data, depth=0):
    """Walk ``data`` and expand every ``"tensor_stat"`` entry in place.

    A dict holding a ``"tensor_stat"`` key is expected to carry a pair
    ``(names, values)``; each ``values[i].item()`` is stored in the dict under
    ``names[i]`` and the ``"tensor_stat"`` key is removed. Dicts without that
    key, lists and tuples are traversed recursively; anything else is ignored.

    Args:
        data: nested structure of dicts / lists / tuples to process.
        depth: current recursion depth (callers normally leave the default).

    Raises:
        MsprobeException: when ``depth`` exceeds ``Const.MAX_DEPTH``.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
        raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)

    if isinstance(data, dict):
        if "tensor_stat" not in data:
            # Plain container dict: descend into every value.
            for child in data.values():
                self.process_stat_data_recursive(child, depth + 1)
            return

        tensor_stat = data["tensor_stat"]
        malformed = (
            len(tensor_stat) != Const.TENSOR_STAT_LEN
            or len(tensor_stat[0]) != len(tensor_stat[1])
        )
        if malformed:
            logger.warning("Some bad data in async dump")
        else:
            stat_names, stat_values = tensor_stat[0], tensor_stat[1]
            # Objects exposing ``device`` are moved to CPU before ``.item()`` is read.
            if hasattr(stat_values, "device") and stat_values.device != Const.CPU_LOWERCASE:
                stat_values = stat_values.cpu()
            for stat_name, stat_value in zip(stat_names, stat_values):
                data[stat_name] = stat_value.item()
        # NOTE(review): the raw entry is assumed to be dropped for both the
        # well-formed and the malformed case — confirm against upstream nesting.
        del data["tensor_stat"]
        return

    if isinstance(data, (list, tuple)):
        for element in data:
            self.process_stat_data_recursive(element, depth + 1)
|
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -45,7 +45,7 @@ class ScopeFactory:
|
|
|
45
45
|
|
|
46
46
|
if self.level == Const.LEVEL_MIX:
|
|
47
47
|
return mix_range_scope
|
|
48
|
-
|
|
48
|
+
|
|
49
49
|
if not self.scope:
|
|
50
50
|
return api_range_scope
|
|
51
51
|
if api_range_scope.is_valid and module_range_scope.is_valid:
|
|
@@ -73,21 +73,21 @@ class BaseScope(ABC):
|
|
|
73
73
|
def rectify_args(scope, api_list):
    """Validate ``scope`` and ``api_list`` and normalise ``scope`` to a list.

    Args:
        scope: a single string or a list of strings.
        api_list: a list of strings.

    Returns:
        tuple: ``(scope_as_list, api_list)``.

    Raises:
        ScopeException: when either argument, or one of its elements, has the
            wrong type.
    """
    if not isinstance(api_list, list):
        raise ScopeException(ScopeException.InvalidApiStr,
                             f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
    # Report the first offending element, in list order.
    non_str_apis = [api for api in api_list if not isinstance(api, str)]
    if non_str_apis:
        api = non_str_apis[0]
        raise ScopeException(ScopeException.InvalidApiStr,
                             f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")

    # A bare string is shorthand for a one-element scope list.
    if isinstance(scope, str):
        return [scope], api_list

    if not isinstance(scope, list):
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
    non_str_scopes = [s for s in scope if not isinstance(s, str)]
    if non_str_scopes:
        s = non_str_scopes[0]
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
    return scope, api_list
|
|
92
92
|
|
|
93
93
|
@abstractmethod
|
|
@@ -108,7 +108,7 @@ class ListScope(BaseScope):
|
|
|
108
108
|
def rectify_args(scope, api_list):
    """Reject a simultaneous ``scope`` + ``api_list`` configuration, then defer to the base check.

    Raises:
        ScopeException: ``ArgConflict`` when both arguments are non-empty.
    """
    both_configured = bool(scope) and bool(api_list)
    if both_configured:
        raise ScopeException(ScopeException.ArgConflict,
                             f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
    base_rectify = super(ListScope, ListScope).rectify_args
    return base_rectify(scope, api_list)
|
|
113
113
|
|
|
114
114
|
def check(self, name):
|
|
@@ -123,6 +123,7 @@ class RangeScope(BaseScope, ABC):
|
|
|
123
123
|
super().__init__(*args)
|
|
124
124
|
self.in_scope = False
|
|
125
125
|
self.in_list = False
|
|
126
|
+
self.start_name_set = set()
|
|
126
127
|
self.is_valid = self.check_scope_is_valid()
|
|
127
128
|
|
|
128
129
|
def check_name_pattern(self, name):
|
|
@@ -133,23 +134,23 @@ class RangeScope(BaseScope, ABC):
|
|
|
133
134
|
if self.level == Const.LEVEL_L1:
|
|
134
135
|
if not re.match(api_pattern, name):
|
|
135
136
|
raise ScopeException(ScopeException.InvalidScope,
|
|
136
|
-
|
|
137
|
-
|
|
137
|
+
f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
|
|
138
|
+
|
|
138
139
|
if self.level == Const.LEVEL_L0:
|
|
139
140
|
if not re.match(module_pattern, name):
|
|
140
141
|
raise ScopeException(ScopeException.InvalidScope,
|
|
141
|
-
|
|
142
|
+
f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
|
|
142
143
|
|
|
143
144
|
if self.level == Const.LEVEL_MIX:
|
|
144
145
|
if not re.match(api_pattern, name) and not re.match(module_pattern, name):
|
|
145
146
|
raise ScopeException(ScopeException.InvalidScope,
|
|
146
|
-
|
|
147
|
+
f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
|
|
147
148
|
|
|
148
149
|
def rectify_args(self, scope, api_list):
    """Apply the base validation, then enforce the range-scope shape.

    A non-empty scope must contain exactly two names (the interval end
    points), and every name must pass ``check_name_pattern``.

    Returns:
        tuple: the validated ``(scope, api_list)``.

    Raises:
        ScopeException: when a non-empty scope does not have length 2, or from
            the base validation / ``check_name_pattern``.
    """
    base_rectify = super(RangeScope, RangeScope).rectify_args
    scope, api_list = base_rectify(scope, api_list)

    if scope and len(scope) != 2:
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")

    for endpoint in scope:
        self.check_name_pattern(endpoint)
    return scope, api_list
|
|
@@ -229,30 +230,31 @@ class ModuleRangeScope(RangeScope):
|
|
|
229
230
|
class MixRangeScope(RangeScope):
|
|
230
231
|
def check_scope_is_valid(self):
    """Return ``True`` when a non-empty scope is configured, ``False`` otherwise."""
    return bool(self.scope)
|
|
232
|
-
|
|
233
|
+
|
|
233
234
|
def begin_module(self, module_name):
    """Update range state when a module starts executing.

    Sets ``in_scope`` when ``module_name`` is the first scope end point, and
    sets ``in_list`` (remembering ``module_name``) when any ``api_list`` entry
    is a substring of ``module_name``.
    """
    opens_scope = bool(self.scope) and module_name == self.scope[0]
    if opens_scope:
        self.in_scope = True

    for fragment in self.api_list:
        if fragment not in module_name:
            continue
        self.in_list = True
        # Remember every module name that switched in_list on.
        self.start_name_set.add(module_name)
|
|
239
241
|
|
|
240
242
|
def end_module(self, module_name):
    """Update range state when a module finishes executing.

    Clears ``in_scope`` when ``module_name`` is the second scope end point,
    forgets ``module_name`` as an ``in_list`` opener, and clears ``in_list``
    once no opener remains.
    """
    closes_scope = bool(self.scope) and module_name == self.scope[1]
    if closes_scope:
        self.in_scope = False

    open_modules = self.start_name_set
    # Drop this module from the set of in_list openers (no-op if absent).
    open_modules.discard(module_name)
    if not open_modules:
        # The last opener has finished, so in_list is switched off.
        self.in_list = False
|
|
246
248
|
|
|
247
249
|
def check_api_list(self, api_name):
    """Tell whether ``api_name`` passes the ``api_list`` filter.

    Returns:
        bool: ``True`` when ``api_list`` is empty or one of its entries is a
        substring of ``api_name``; ``False`` otherwise.
    """
    filters = self.api_list
    return not filters or any(fragment in api_name for fragment in filters)
|
|
255
|
-
|
|
257
|
+
|
|
256
258
|
def check(self, name):
|
|
257
259
|
"""
|
|
258
260
|
dump时调用的接口,根据scope和api_list判断是否需要dump
|
|
@@ -270,4 +272,3 @@ class MixRangeScope(RangeScope):
|
|
|
270
272
|
if self.scope and name == self.scope[1]:
|
|
271
273
|
self.in_scope = False
|
|
272
274
|
return result
|
|
273
|
-
|
|
@@ -37,7 +37,11 @@ class AnomalyScene:
|
|
|
37
37
|
@staticmethod
|
|
38
38
|
def _has_anomaly(data: Union[Dict, Any]) -> bool:
|
|
39
39
|
"""检查张量是否包含异常值"""
|
|
40
|
-
|
|
40
|
+
if isinstance(data, dict):
|
|
41
|
+
return has_nan_inf(data)
|
|
42
|
+
elif isinstance(data, list):
|
|
43
|
+
return any(AnomalyScene._has_anomaly(x) for x in data)
|
|
44
|
+
return False
|
|
41
45
|
|
|
42
46
|
def get_details(self) -> Dict:
|
|
43
47
|
"""获取异常详情"""
|
|
@@ -70,14 +74,14 @@ class InputOutputAnomalyScene(AnomalyScene):
|
|
|
70
74
|
def has_input_anomaly(self) -> bool:
    """Check whether any input (positional args or kwargs values) is anomalous."""
    # Positional arguments.
    found_in_args = any(map(self._has_anomaly, self.api_data.input_args))
    # Keyword argument values; evaluated even when the args already matched.
    found_in_kwargs = any(map(self._has_anomaly, self.api_data.input_kwargs.values()))
    return found_in_args or found_in_kwargs
|
|
77
81
|
|
|
78
82
|
def has_output_anomaly(self) -> bool:
    """Check whether any element of the output data is anomalous."""
    outputs = self.api_data.output_data
    return any(map(self._has_anomaly, outputs))
|
|
81
85
|
|
|
82
86
|
def matches(self) -> bool:
|
|
83
87
|
"""判断是否匹配该场景"""
|
|
@@ -121,7 +125,7 @@ class NumericalMutationScene(AnomalyScene):
|
|
|
121
125
|
"""
|
|
122
126
|
检查数值突变,统计输入args、kwargs中norm值,同时统计输出的norm最大值,计算差异,大于 threshold 则认为是异常情况
|
|
123
127
|
"""
|
|
124
|
-
def __init__(self, api_info: APIInfo, threshold: float =
|
|
128
|
+
def __init__(self, api_info: APIInfo, threshold: float = 100.0):
|
|
125
129
|
super().__init__(api_info)
|
|
126
130
|
self.threshold = threshold
|
|
127
131
|
|