mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
from collections import OrderedDict, defaultdict
|
|
20
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Dict, List, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
import pytz
|
|
25
|
+
from msprobe.core.common.const import MonitorConst
|
|
26
|
+
from msprobe.core.common.file_utils import (create_directory, read_csv,
|
|
27
|
+
recursive_chmod, remove_path)
|
|
28
|
+
from msprobe.core.common.log import logger
|
|
29
|
+
from msprobe.core.common.utils import is_int
|
|
30
|
+
from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict
|
|
31
|
+
from msprobe.core.monitor.utils import get_target_output_dir
|
|
32
|
+
from tqdm import tqdm
|
|
33
|
+
|
|
34
|
+
# Constants
# Every metric type this converter understands. It is the default whitelist
# when CSV2DBConfig.data_type_list is empty, and the reference set that
# validate_data_type_list checks user-supplied types against.
all_data_type_list = [
    "actv",
    "actv_grad",
    "exp_avg",
    "exp_avg_sq",
    "grad_unreduced",
    "grad_reduced",
    "param_origin",
    "param_updated",
    "other",
]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class CSV2DBConfig:
    """Options for converting monitor CSV output into a database.

    Attributes:
        monitor_path: monitor output directory; passed to
            ``get_target_output_dir`` to locate the per-rank CSV folders.
        time_start: optional lower time bound forwarded to
            ``get_target_output_dir``.
        time_end: optional upper time bound, likewise.
        process_num: number of worker processes used for scanning/importing.
        data_type_list: metric types to convert; None/empty means all types.
        output_dirpath: destination directory; None means a timestamped
            ``<monitor_path>/<time>-csv2db`` folder is chosen by ``csv2db``.
        step_partition: number of consecutive steps stored per metric table.
    """

    # Required input location.
    monitor_path: str
    # Optional time window used to select output directories.
    time_start: Optional[str] = None
    time_end: Optional[str] = None
    # Parallelism for pre-scan and import.
    process_num: int = 1
    # Metric-type filter (None -> every supported type).
    data_type_list: Optional[List[str]] = None
    # Where the database file is written.
    output_dirpath: Optional[str] = None
    # Partition width (in steps) of each metric table.
    step_partition: int = 500
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def validate_process_num(process_num: int) -> None:
    """Check that ``process_num`` is a usable worker count.

    Raises:
        ValueError: if it is not a positive integer (per ``is_int``) or if it
            exceeds ``MonitorConst.MAX_PROCESS_NUM``.
    """
    is_positive_int = is_int(process_num) and process_num > 0
    if not is_positive_int:
        raise ValueError("process_num must be a positive integer")

    upper_bound = MonitorConst.MAX_PROCESS_NUM
    if process_num > upper_bound:
        raise ValueError(f"Maximum supported process_num is {upper_bound}")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def validate_step_partition(step_partition: int) -> None:
    """Validate the number of consecutive steps stored per metric table.

    Raises:
        TypeError: if ``step_partition`` is not an ``int``.
        ValueError: if it lies outside
            ``[MonitorConst.MIN_PARTITION, MonitorConst.MAX_PARTITION]``.
    """
    if not isinstance(step_partition, int):
        raise TypeError("step_partition must be integer")
    if not MonitorConst.MIN_PARTITION <= step_partition <= MonitorConst.MAX_PARTITION:
        # Adjacent f-strings (no comma) form ONE message; separating them with a
        # comma would pass a tuple to ValueError and render as "('...', '...')".
        raise ValueError(
            f"step_partition must be between {MonitorConst.MIN_PARTITION} "
            f"and {MonitorConst.MAX_PARTITION}, got {step_partition}"
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def validate_data_type_list(data_type_list: Optional[List[str]]) -> None:
    """Validate the user-selected metric types.

    A None or empty value is accepted; in that case the default
    ``all_data_type_list`` is announced in the log.

    Raises:
        ValueError: if the value is not a list, or if it names a type that is
            not in ``all_data_type_list``.
    """
    if not data_type_list:
        logger.info(f"Using default data types: {all_data_type_list}")
        return

    if not isinstance(data_type_list, list):
        raise ValueError("data_type_list must be a list")

    unsupported = [name for name in data_type_list if name not in all_data_type_list]
    if unsupported:
        raise ValueError(f"Unsupported data types: {unsupported}")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_info_from_filename(file_name, metric_list=None):
    """Split a monitor CSV file name into its metric name and step range.

    The metric name is everything before the last ``_``. The file name is then
    matched (from its start) against the metric name followed by
    ``MonitorConst.CSV_FILE_PATTERN``, whose two capture groups are taken as the
    start and end step.

    Args:
        file_name: base name of the CSV file.
        metric_list: optional whitelist of metric names; when non-empty, names
            outside it are rejected.

    Returns:
        ``(metric_name, step_start, step_end)`` with the steps as the matched
        strings, or ``("", 0, 0)`` when the name is filtered out or does not
        match the pattern.
    """
    metric_name = "_".join(file_name.split('_')[:-1])
    if metric_list and metric_name not in metric_list:
        return "", 0, 0
    # The metric name is derived from an arbitrary file name, so it must be
    # matched literally: unescaped characters such as '.', '+' or '(' would
    # otherwise be interpreted as regex syntax (or make the pattern invalid).
    pattern = f"{re.escape(metric_name)}{MonitorConst.CSV_FILE_PATTERN}"
    match = re.match(pattern, file_name)
    if not match:
        return "", 0, 0
    step_start, step_end = match.groups()
    return metric_name, step_start, step_end
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
|
|
98
|
+
"""Pre-scan files for a single rank to collect metadata"""
|
|
99
|
+
metrics = set()
|
|
100
|
+
min_step = None
|
|
101
|
+
max_step = 0
|
|
102
|
+
metric_stats = defaultdict(set)
|
|
103
|
+
targets = OrderedDict()
|
|
104
|
+
|
|
105
|
+
for file_path in files:
|
|
106
|
+
file_name = os.path.basename(file_path)
|
|
107
|
+
metric_name, step_start, step_end = get_info_from_filename(file_name)
|
|
108
|
+
if not metric_name:
|
|
109
|
+
continue
|
|
110
|
+
step_start, step_end = int(step_start), int(step_end)
|
|
111
|
+
|
|
112
|
+
metrics.add(metric_name)
|
|
113
|
+
min_step = min(
|
|
114
|
+
step_start if min_step is None else min_step, step_start)
|
|
115
|
+
max_step = max(max_step, step_end)
|
|
116
|
+
|
|
117
|
+
data = read_csv(file_path)
|
|
118
|
+
stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED]
|
|
119
|
+
metric_stats[metric_name].update(stats)
|
|
120
|
+
|
|
121
|
+
for row_id, row in data.iterrows():
|
|
122
|
+
try:
|
|
123
|
+
name = row[MonitorConst.HEADER_NAME]
|
|
124
|
+
vpp_stage = int(row['vpp_stage'])
|
|
125
|
+
micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE))
|
|
126
|
+
except (ValueError, KeyError) as e:
|
|
127
|
+
logger.warning(
|
|
128
|
+
f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}")
|
|
129
|
+
continue
|
|
130
|
+
target = (name, vpp_stage, micro_step)
|
|
131
|
+
if target not in targets:
|
|
132
|
+
targets[target] = None
|
|
133
|
+
|
|
134
|
+
return {
|
|
135
|
+
'max_rank': int(rank),
|
|
136
|
+
'metrics': metrics,
|
|
137
|
+
'min_step': min_step,
|
|
138
|
+
'max_step': max_step,
|
|
139
|
+
'metric_stats': metric_stats,
|
|
140
|
+
'targets': list(targets.keys())
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1):
    """Pre-scan every rank and register the discovered dimensions in the database.

    Lists each rank directory, keeps the files whose metric name is in
    ``data_type_list``, scans them in parallel with ``_pre_scan_single_rank``
    and writes the merged targets/metrics/stats plus global step/rank bounds
    via ``monitor_db``.

    Args:
        monitor_db: database receiving the dimension rows and global stats.
        data_dirs: rank -> directory containing that rank's CSV files.
        data_type_list: metric names to keep.
        workers: size of the process pool.

    Returns:
        ``{rank: [csv_path, ...]}`` for every rank with at least one matching
        file (the work list for the import phase).
    """
    logger.info("Scanning dimensions...")
    rank_files = defaultdict(list)

    # Collect the candidate files of each rank, filtered by metric type.
    for rank, dir_path in data_dirs.items():
        for file in os.listdir(dir_path):
            metric_name, _, _ = get_info_from_filename(file, metric_list=data_type_list)
            if metric_name:
                rank_files[rank].append(os.path.join(dir_path, file))

    # Parallel pre-scan
    results = []
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(_pre_scan_single_rank, rank, files): rank
            for rank, files in rank_files.items()
        }
        with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar:
            for future in as_completed(futures):
                rank = futures[future]
                try:
                    results.append(future.result())
                except Exception as e:
                    logger.error(
                        f"Error pre-scanning rank {rank}: {str(e)}")
                pbar.update(1)

    # as_completed yields in completion order, which varies from run to run.
    # Merge in rank order so targets (and the autoincrement target_ids assigned
    # from their insertion order) are deterministic.
    results.sort(key=lambda item: item['max_rank'])

    targets = OrderedDict()
    metrics = set()
    min_step = None
    max_step = 0
    max_rank = 0
    metric_stats = defaultdict(set)

    for rank_result in results:
        max_rank = max(max_rank, rank_result['max_rank'])
        metrics.update(rank_result['metrics'])
        # A rank whose files yielded no step range reports None; ignore it
        # instead of letting min() compare None with int.
        rank_min = rank_result['min_step']
        if rank_min is not None:
            min_step = rank_min if min_step is None else min(min_step, rank_min)
        max_step = max(max_step, rank_result['max_step'])

        for metric, stats in rank_result['metric_stats'].items():
            metric_stats[metric].update(stats)

        targets = update_ordered_dict(targets, rank_result['targets'])

    monitor_db.insert_dimensions(
        targets, metrics, metric_stats, min_step=min_step, max_step=max_step)
    monitor_db.update_global_stats(
        max_rank=max_rank, min_step=min_step, max_step=max_step)
    return rank_files
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def process_single_rank(
    task: Tuple[int, List[str]],
    metric_id_dict: Dict[str, Tuple[int, List[str]]],
    target_dict: Dict[Tuple[str, int, int], int],
    step_partition_size: int,
    db_path: str
) -> int:
    """Import every CSV row of one rank into the partitioned metric tables.

    Runs in a worker process with its own ``MonitorDB`` instance. Each row is
    turned into ``(rank, step, target_id, *stat_values)`` and buffered per
    destination table; a buffer is written once it reaches
    ``MonitorConst.BATCH_SIZE`` rows, and leftovers are written at the end.
    Rows with an unknown target are skipped silently; unparsable rows are
    logged and skipped.

    Args:
        task: ``(rank, csv_file_paths)``.
        metric_id_dict: metric name -> ``(metric_id, stat column names)``.
        target_dict: ``(name, vpp_stage, micro_step)`` -> target_id.
        step_partition_size: steps per partition table.
        db_path: path of the database file.

    Returns:
        Sum of the non-None counts returned by ``insert_rows``.
    """
    rank, files = task
    db = MonitorDB(db_path, step_partition_size=step_partition_size)
    pending_rows = defaultdict(list)
    inserted_rows = 0

    def _flush(table, rows):
        # Write one buffered batch and accumulate the reported row count.
        nonlocal inserted_rows
        count = db.insert_rows(table, rows)
        if count is not None:
            inserted_rows += count

    for csv_path in files:
        metric_name = get_info_from_filename(os.path.basename(csv_path))[0]
        metric_info = metric_id_dict.get(metric_name) if metric_name else None
        if not metric_info:
            continue
        metric_id, stats = metric_info

        for row_id, row in read_csv(csv_path).iterrows():
            try:
                target_id = target_dict.get((
                    row.get(MonitorConst.HEADER_NAME),
                    int(row['vpp_stage']),
                    int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE)),
                ))
                if not target_id:
                    continue  # target not present in target_dict
                step = int(row['step'])
                table_name, _, _ = db.get_metric_table_name(metric_id, step)
                values = [float(row[stat]) if stat in row else None for stat in stats]
            except (ValueError, KeyError) as e:
                logger.error(
                    f"CSV conversion failed | file={csv_path}:{row_id+2} | error={str(e)}")
                continue

            batch = pending_rows[table_name]
            batch.append((rank, step, target_id, *values))
            if len(batch) >= MonitorConst.BATCH_SIZE:
                _flush(table_name, batch)
                pending_rows[table_name] = []

    for table_name, batch in pending_rows.items():
        if batch:
            _flush(table_name, batch)

    logger.info(f"Rank {rank} inserted {inserted_rows} rows")
    return inserted_rows
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool:
    """Create the schema and import all ranks' CSV data into ``monitor_db``.

    Steps: initialise the schema, pre-scan ranks to register dimensions, load
    the metric/target id mappings, then import each rank in a worker process.

    Args:
        monitor_db: destination database.
        data_dirs: rank -> directory with that rank's CSV files.
        data_type_list: metric names to import.
        workers: process pool size for both pre-scan and import.

    Returns:
        True if every rank imported without raising; False if pre-scan found
        no files, the id mappings could not be read, or any rank failed.
    """
    # 1. Pre-scan to get rank tasks
    monitor_db.init_schema()
    rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers)
    if not rank_tasks:
        logger.error("No valid data files found during pre-scan")
        return False

    # 2. Get metric and target mappings
    try:
        metric_id_dict = monitor_db.get_metric_mapping()
        target_dict = monitor_db.get_target_mapping()
    except Exception as e:
        logger.error(f"Failed to get database mappings: {str(e)}")
        return False

    # 3. Process data for each rank in parallel
    total_files = sum(len(files) for files in rank_tasks.values())
    logger.info(f"Starting data import for {len(rank_tasks)} ranks, {total_files} files...")
    all_succeeded = True
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(
                process_single_rank,
                (rank, files),
                metric_id_dict,
                target_dict,
                monitor_db.step_partition_size,
                monitor_db.db_path): rank
            for rank, files in rank_tasks.items()
        }

        with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar:
            for future in pbar:
                rank = futures[future]
                try:
                    inserted = future.result()
                    pbar.set_postfix_str(
                        f"Rank {rank}: inserted {inserted} rows")
                except Exception as e:
                    # Keep importing the other ranks; report overall failure.
                    logger.error(
                        f"Failed to process Rank {rank}: {str(e)}")
                    all_succeeded = False
    return all_succeeded
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def csv2db(config: CSV2DBConfig) -> None:
    """Convert the monitor CSV outputs selected by ``config`` into a database.

    Validates the configuration, resolves the per-rank output directories with
    ``get_target_output_dir``, (re)creates ``<output_dirpath>/monitor_metrics.db``
    and imports the data into it.

    Note: when ``config.output_dirpath`` is None it is set on the passed
    ``config`` object to a timestamped ``<monitor_path>/<time>-csv2db`` path.

    Raises:
        ValueError, TypeError: propagated from the parameter validators.
    """
    validate_process_num(config.process_num)
    validate_step_partition(config.step_partition)
    validate_data_type_list(config.data_type_list)

    target_output_dirs = get_target_output_dir(
        config.monitor_path, config.time_start, config.time_end)

    if config.output_dirpath is None:
        local_tz = pytz.timezone("Asia/Shanghai")
        cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
        config.output_dirpath = os.path.join(
            config.monitor_path, f"{cur_time}-csv2db")

    create_directory(config.output_dirpath)
    db_path = os.path.join(config.output_dirpath, "monitor_metrics.db")

    if os.path.exists(db_path):
        # Warn *before* deleting, and state what actually happens: the existing
        # database is removed and rebuilt from scratch.
        logger.warning(f"Existing path {db_path} will be overwritten")
        remove_path(db_path)

    db = MonitorDB(db_path, step_partition_size=config.step_partition)

    result = import_data(
        db,
        target_output_dirs,
        config.data_type_list if config.data_type_list else all_data_type_list,
        workers=config.process_num
    )
    recursive_chmod(config.output_dirpath)
    if result:
        logger.info(
            f"Data import completed. Output saved to: {config.output_dirpath}")
    else:
        logger.warning(
            f"Data import may be incomplete. Output directory: {config.output_dirpath} "
            "(Some records might have failed)"
        )
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
from collections import OrderedDict
|
|
16
|
+
from collections.abc import Iterable
|
|
17
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import MonitorConst
|
|
20
|
+
from msprobe.core.common.db_manager import DBManager
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def update_ordered_dict(main_dict: OrderedDict, new_list: List) -> OrderedDict:
    """Append the unseen items of ``new_list`` as keys of ``main_dict``.

    New keys get the value ``None`` and are added in the order they appear;
    keys already present keep their position and value. ``main_dict`` is
    modified in place and also returned, so it works as an ordered set.
    """
    for candidate in new_list:
        main_dict.setdefault(candidate, None)
    return main_dict
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_ordered_stats(stats: Iterable) -> List[str]:
    """Return the supported members of ``stats`` in canonical order.

    The order follows ``MonitorConst.OP_MONVIS_SUPPORTED``; names not listed
    there are dropped. A non-iterable argument yields an empty list.
    """
    if isinstance(stats, Iterable):
        return list(filter(lambda name: name in stats, MonitorConst.OP_MONVIS_SUPPORTED))
    return []
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MonitorSql:
    """SQL statements (table DDL and queries) for the monitoring database."""

    @staticmethod
    def create_monitoring_targets_table():
        """DDL for ``monitoring_targets``: one row per unique (target_name, vpp_stage, micro_step)."""
        return """
            CREATE TABLE IF NOT EXISTS monitoring_targets (
                target_id INTEGER PRIMARY KEY AUTOINCREMENT,
                target_name TEXT NOT NULL,
                vpp_stage INTEGER NOT NULL,
                micro_step INTEGER NOT NULL DEFAULT 0,
                UNIQUE(target_name, vpp_stage, micro_step)
            )"""

    @staticmethod
    def create_monitoring_metrics_table():
        """DDL for ``monitoring_metrics``: one row per unique metric name."""
        return """
            CREATE TABLE IF NOT EXISTS monitoring_metrics (
                metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
                metric_name TEXT UNIQUE NOT NULL
            )"""

    @staticmethod
    def get_metric_mapping_sql():
        """Query yielding ``(metric_id, metric_name, stats)`` per metric.

        ``stats`` is the GROUP_CONCAT of the metric's ``metric_stats.stat_name``
        values (NULL when the metric has none, because of the LEFT JOIN).
        """
        return """
            SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats
            FROM monitoring_metrics m
            LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id
            GROUP BY m.metric_id
        """

    @staticmethod
    def create_metric_stats_table():
        """DDL for ``metric_stats``: the statistic names recorded for each metric."""
        return """
            CREATE TABLE IF NOT EXISTS metric_stats (
                metric_id INTEGER NOT NULL,
                stat_name TEXT NOT NULL,
                PRIMARY KEY (metric_id, stat_name),
                FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id)
            ) WITHOUT ROWID"""

    @staticmethod
    def create_global_stat_table():
        """DDL for ``global_stats``: named integer values describing the database."""
        return """
            CREATE TABLE IF NOT EXISTS global_stats (
                stat_name TEXT PRIMARY KEY,
                stat_value INTEGER NOT NULL
            ) WITHOUT ROWID"""

    @classmethod
    def get_table_definition(cls, table_name=""):
        """
        Return the CREATE statement(s) for the fixed tables.

        :param table_name: table to fetch; an empty string selects all tables.
        :return: one SQL string, or a list with every table's SQL when
            ``table_name`` is empty.
        :raises ValueError: if ``table_name`` is not a known table.
        """
        table_creators = {
            "monitoring_targets": cls.create_monitoring_targets_table,
            "monitoring_metrics": cls.create_monitoring_metrics_table,
            "metric_stats": cls.create_metric_stats_table,
            "global_stats": cls.create_global_stat_table,
        }
        if not table_name:
            return [creator() for creator in table_creators.values()]
        if table_name not in table_creators:
            raise ValueError(f"Unsupported table name: {table_name}")
        return table_creators[table_name]()

    @classmethod
    def get_metric_table_definition(cls, table_name, stats, patition=None):
        """
        Build the CREATE TABLE statement for one metric partition table.

        :param table_name: name of the table to create.
        :param stats: statistic names; each becomes a nullable REAL column.
        :param patition: optional ``(start_step, end_step)`` pair (parameter
            name kept for backward compatibility). When given, a CHECK
            constraint limits ``step`` to that inclusive range.
        :return: CREATE TABLE SQL string.
        """
        if patition and len(patition) == 2:
            partition_start_step, partition_end_step = patition
            step_column = (
                f"step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} "
                f"AND {partition_end_step})"
            )
        else:
            step_column = "step INTEGER NOT NULL"

        # Join every column/constraint definition in one place so separators
        # are always correct: an un-partitioned step column must still be
        # followed by a comma, and an empty ``stats`` must not leave a
        # dangling one.
        column_defs = [
            "rank INTEGER NOT NULL",
            step_column,
            "target_id INTEGER NOT NULL",
        ]
        column_defs.extend(f"{stat} REAL DEFAULT NULL" for stat in stats)
        column_defs.append("PRIMARY KEY (rank, step, target_id)")
        column_defs.append("FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id)")
        joined_defs = ",\n    ".join(column_defs)
        return f"\nCREATE TABLE {table_name} (\n    {joined_defs}\n) WITHOUT ROWID\n"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class MonitorDB:
|
|
133
|
+
"""Main class for monitoring database operations"""
|
|
134
|
+
|
|
135
|
+
def __init__(self, db_path: str, step_partition_size: int = 500):
    """Set up access to the database file at ``db_path``.

    Args:
        db_path: database file path; stored and handed to ``DBManager``.
        step_partition_size: number of consecutive steps per metric partition
            table (used by ``get_metric_table_name``).
    """
    self.step_partition_size = step_partition_size
    self.db_path = db_path
    self.db_manager = DBManager(db_path)
|
|
139
|
+
|
|
140
|
+
def get_metric_table_name(self, metric_id: int, step: int) -> str:
|
|
141
|
+
"""Generate metric table name"""
|
|
142
|
+
step_start = (
|
|
143
|
+
step // self.step_partition_size) * self.step_partition_size
|
|
144
|
+
step_end = step_start + self.step_partition_size - 1
|
|
145
|
+
return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end
|
|
146
|
+
|
|
147
|
+
def init_schema(self) -> None:
|
|
148
|
+
"""Initialize database schema"""
|
|
149
|
+
self.db_manager.execute_multi_sql(MonitorSql.get_table_definition())
|
|
150
|
+
|
|
151
|
+
# Insert initial global stats
|
|
152
|
+
global_stats = [
|
|
153
|
+
('max_rank', 0),
|
|
154
|
+
('min_step', 0),
|
|
155
|
+
('max_step', 0),
|
|
156
|
+
('step_partition_size', self.step_partition_size)
|
|
157
|
+
]
|
|
158
|
+
self.db_manager.insert_data("global_stats", global_stats)
|
|
159
|
+
|
|
160
|
+
def insert_dimensions(
|
|
161
|
+
self,
|
|
162
|
+
targets: OrderedDict,
|
|
163
|
+
metrics: Set[str],
|
|
164
|
+
metric_stats: Dict[str, Set[str]],
|
|
165
|
+
min_step: Optional[int] = None,
|
|
166
|
+
max_step: int = None,
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Insert dimension data into database"""
|
|
169
|
+
# Insert targets
|
|
170
|
+
self.db_manager.insert_data(
|
|
171
|
+
"monitoring_targets",
|
|
172
|
+
[(name, vpp_stage, micro_step)
|
|
173
|
+
for (name, vpp_stage, micro_step) in targets],
|
|
174
|
+
key_list=["target_name", "vpp_stage", "micro_step"]
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# Insert metrics
|
|
178
|
+
self.db_manager.insert_data(
|
|
179
|
+
"monitoring_metrics",
|
|
180
|
+
[(metric,) for metric in metrics],
|
|
181
|
+
key_list=["metric_name"]
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
# Insert metric-stat relationships
|
|
185
|
+
for metric, stats in metric_stats.items():
|
|
186
|
+
metric_id = self._get_metric_id(metric)
|
|
187
|
+
ordered_stats = get_ordered_stats(stats)
|
|
188
|
+
|
|
189
|
+
self.db_manager.insert_data(
|
|
190
|
+
"metric_stats",
|
|
191
|
+
[(metric_id, stat) for stat in ordered_stats],
|
|
192
|
+
key_list=["metric_id", "stat_name"]
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Create metric tables for each partition
|
|
196
|
+
if min_step is not None and max_step is not None:
|
|
197
|
+
first_partition = min_step // self.step_partition_size
|
|
198
|
+
last_partition = max_step // self.step_partition_size
|
|
199
|
+
|
|
200
|
+
for partition in range(first_partition, last_partition + 1):
|
|
201
|
+
step_start = partition * self.step_partition_size
|
|
202
|
+
self.create_metric_table(
|
|
203
|
+
metric_id, step_start, ordered_stats)
|
|
204
|
+
|
|
205
|
+
def insert_rows(self, table_name, rows):
|
|
206
|
+
if not self.db_manager.table_exists(table_name):
|
|
207
|
+
raise RuntimeError(f"{table_name} not existed in {self.db_path}")
|
|
208
|
+
inserted = self.db_manager.insert_data(table_name, rows)
|
|
209
|
+
inserted = 0 if inserted is None else inserted
|
|
210
|
+
return inserted
|
|
211
|
+
|
|
212
|
+
def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str:
|
|
213
|
+
"""Create metric table for a specific partition"""
|
|
214
|
+
table_name, partition_start_step, partition_end_step = self.get_metric_table_name(
|
|
215
|
+
metric_id,
|
|
216
|
+
step
|
|
217
|
+
)
|
|
218
|
+
if self.db_manager.table_exists(table_name):
|
|
219
|
+
return table_name
|
|
220
|
+
|
|
221
|
+
create_sql = MonitorSql.get_metric_table_definition(
|
|
222
|
+
table_name, stats, patition=(
|
|
223
|
+
partition_start_step, partition_end_step)
|
|
224
|
+
)
|
|
225
|
+
self.db_manager.execute_sql(create_sql)
|
|
226
|
+
return table_name
|
|
227
|
+
|
|
228
|
+
def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None, max_step: int = None) -> None:
|
|
229
|
+
"""Update global statistics"""
|
|
230
|
+
updates = [
|
|
231
|
+
("max_rank", max_rank),
|
|
232
|
+
("min_step", min_step),
|
|
233
|
+
("max_step", max_step)
|
|
234
|
+
]
|
|
235
|
+
for stat_name, value in updates:
|
|
236
|
+
if not value:
|
|
237
|
+
continue
|
|
238
|
+
self.db_manager.update_data(
|
|
239
|
+
table_name="global_stats",
|
|
240
|
+
updates={"stat_value": value},
|
|
241
|
+
where={"stat_name": stat_name}
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]:
|
|
245
|
+
"""Get metric name to ID mapping with statistics"""
|
|
246
|
+
results = self.db_manager.execute_sql(
|
|
247
|
+
MonitorSql.get_metric_mapping_sql()
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
return {
|
|
251
|
+
row["metric_name"]: (
|
|
252
|
+
row["metric_id"],
|
|
253
|
+
get_ordered_stats(row["stats"].split(",")
|
|
254
|
+
) if row["stats"] else []
|
|
255
|
+
) for row in results
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]:
|
|
259
|
+
"""Get target mapping dictionary"""
|
|
260
|
+
results = self.db_manager.select_data(
|
|
261
|
+
table_name="monitoring_targets",
|
|
262
|
+
columns=["target_id", "target_name", "vpp_stage", "micro_step"]
|
|
263
|
+
)
|
|
264
|
+
if not results:
|
|
265
|
+
return {}
|
|
266
|
+
return {
|
|
267
|
+
(row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"]
|
|
268
|
+
for row in results
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
def _get_metric_id(self, metric_name: str) -> Optional[int]:
|
|
272
|
+
"""Get metric ID by name"""
|
|
273
|
+
result = self.db_manager.select_data(
|
|
274
|
+
table_name="monitoring_metrics",
|
|
275
|
+
columns=["metric_id"],
|
|
276
|
+
where={"metric_name": metric_name}
|
|
277
|
+
)
|
|
278
|
+
return result[0]["metric_id"] if result else None
|
msprobe/core/monitor/utils.py
CHANGED
|
@@ -96,8 +96,33 @@ def validate_targets(targets):
|
|
|
96
96
|
raise TypeError('key of targets should be module_name[str] in config.json')
|
|
97
97
|
if not isinstance(field, dict):
|
|
98
98
|
raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
|
|
99
|
+
|
|
99
100
|
|
|
100
|
-
|
|
101
|
+
def validate_l2_targets(targets):
    """Validate the ``l2_targets`` section of config.json.

    ``targets`` must be a dict mapping a hook name (a member of
    ``MonitorConst.L2_HOOKS``) to a list of module-name strings.

    Raises:
        TypeError: if ``targets`` is not a dict, a key is not a known hook,
            a value is not a list, or a list item is not a string.
    """
    if not isinstance(targets, dict):
        raise TypeError('l2_targets in config.json should be a dict')
    for hook_name, target_list in targets.items():
        if hook_name not in MonitorConst.L2_HOOKS:
            raise TypeError(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}, got {hook_name}')
        if not isinstance(target_list, list):
            raise TypeError('values of l2_targets should be a list in config.json')
        for item in target_list:
            if not isinstance(item, str):
                raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json')
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def validate_recording_l2_features(recording_l2_features):
    """Ensure the ``recording_l2_features`` config option is a genuine bool.

    Raises:
        TypeError: for any non-bool value (including ints such as ``1``).
    """
    if isinstance(recording_l2_features, bool):
        return
    raise TypeError("recording_l2_features should be a bool")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def validate_sa_order(sa_order):
    """Check that ``sa_order`` is one of ``MonitorConst.SA_ORDERS``.

    String values have every space removed before the membership check; the
    normalized value is used only for the check and is not returned.

    Raises:
        TypeError: if the (normalized) value is not an allowed order.
    """
    normalized = sa_order.replace(' ', '') if isinstance(sa_order, str) else sa_order
    if normalized in MonitorConst.SA_ORDERS:
        return
    raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {normalized}')
|
|
124
|
+
|
|
125
|
+
|
|
101
126
|
def validate_print_struct(print_struct):
|
|
102
127
|
if not isinstance(print_struct, bool):
|
|
103
128
|
raise TypeError("print_struct should be a bool")
|
|
@@ -216,6 +241,15 @@ def validate_config(config):
|
|
|
216
241
|
targets = config.get("targets", {})
|
|
217
242
|
validate_targets(targets)
|
|
218
243
|
|
|
244
|
+
l2_targets = config.get("l2_targets", {})
|
|
245
|
+
validate_l2_targets(l2_targets)
|
|
246
|
+
|
|
247
|
+
recording_l2_features = config.get("recording_l2_features", False)
|
|
248
|
+
validate_recording_l2_features(recording_l2_features)
|
|
249
|
+
|
|
250
|
+
sa_order = config.get("sa_order", "s,b,h,d")
|
|
251
|
+
validate_sa_order(sa_order)
|
|
252
|
+
|
|
219
253
|
print_struct = config.get('print_struct', False)
|
|
220
254
|
validate_print_struct(print_struct)
|
|
221
255
|
|