mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/core/monitor/csv2db.py
@@ -0,0 +1,361 @@
+ # Copyright (c) 2025-2026, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import datetime
+ import os
+ import re
+ from collections import OrderedDict, defaultdict
+ from concurrent.futures import ProcessPoolExecutor, as_completed
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple
+
+ import pytz
+ from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.file_utils import (create_directory, read_csv,
+                                             recursive_chmod, remove_path)
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.utils import is_int
+ from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict
+ from msprobe.core.monitor.utils import get_target_output_dir
+ from tqdm import tqdm
+
+ # Constants
+ all_data_type_list = [
+     "actv", "actv_grad", "exp_avg", "exp_avg_sq",
+     "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"
+ ]
+
+
+ @dataclass
+ class CSV2DBConfig:
+     """Configuration for CSV to database conversion"""
+     monitor_path: str
+     time_start: Optional[str] = None
+     time_end: Optional[str] = None
+     process_num: int = 1
+     data_type_list: Optional[List[str]] = None
+     output_dirpath: Optional[str] = None
+     step_partition: int = 500
+
+
+ def validate_process_num(process_num: int) -> None:
+     """Validate process number parameter"""
+     if not is_int(process_num) or process_num <= 0:
+         raise ValueError("process_num must be a positive integer")
+     if process_num > MonitorConst.MAX_PROCESS_NUM:
+         raise ValueError(f"Maximum supported process_num is {MonitorConst.MAX_PROCESS_NUM}")
+
+
+ def validate_step_partition(step_partition: int) -> None:
+     if not isinstance(step_partition, int):
+         raise TypeError("step_partition must be an integer")
+     if not MonitorConst.MIN_PARTITION <= step_partition <= MonitorConst.MAX_PARTITION:
+         raise ValueError(
+             f"step_partition must be between {MonitorConst.MIN_PARTITION} "
+             f"and {MonitorConst.MAX_PARTITION}, got {step_partition}"
+         )
+
+
+ def validate_data_type_list(data_type_list: Optional[List[str]]) -> None:
+     """Validate data type list parameter"""
+     if not data_type_list:
+         logger.info(f"Using default data types: {all_data_type_list}")
+         return
+
+     if not isinstance(data_type_list, list):
+         raise ValueError("data_type_list must be a list")
+
+     invalid_types = [t for t in data_type_list if t not in all_data_type_list]
+     if invalid_types:
+         raise ValueError(f"Unsupported data types: {invalid_types}")
+
+
+ def get_info_from_filename(file_name, metric_list=None):
+     metric_name = "_".join(file_name.split('_')[:-1])
+     if metric_list and metric_name not in metric_list:
+         return "", 0, 0
+     match = re.match(f"{metric_name}{MonitorConst.CSV_FILE_PATTERN}", file_name)
+     if not match:
+         return "", 0, 0
+     step_start, step_end = match.groups()
+     return metric_name, step_start, step_end
+
+
+ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
+     """Pre-scan files for a single rank to collect metadata"""
+     metrics = set()
+     min_step = None
+     max_step = 0
+     metric_stats = defaultdict(set)
+     targets = OrderedDict()
+
+     for file_path in files:
+         file_name = os.path.basename(file_path)
+         metric_name, step_start, step_end = get_info_from_filename(file_name)
+         if not metric_name:
+             continue
+         step_start, step_end = int(step_start), int(step_end)
+
+         metrics.add(metric_name)
+         min_step = min(
+             step_start if min_step is None else min_step, step_start)
+         max_step = max(max_step, step_end)
+
+         data = read_csv(file_path)
+         stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED]
+         metric_stats[metric_name].update(stats)
+
+         for row_id, row in data.iterrows():
+             try:
+                 name = row[MonitorConst.HEADER_NAME]
+                 vpp_stage = int(row['vpp_stage'])
+                 micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE))
+             except (ValueError, KeyError) as e:
+                 logger.warning(
+                     f"CSV conversion failed | file={file_path}:{row_id + 2} | error={str(e)}")
+                 continue
+             target = (name, vpp_stage, micro_step)
+             if target not in targets:
+                 targets[target] = None
+
+     return {
+         'max_rank': int(rank),
+         'metrics': metrics,
+         'min_step': min_step,
+         'max_step': max_step,
+         'metric_stats': metric_stats,
+         'targets': list(targets.keys())
+     }
+
+
+ def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1):
+     """Pre-scan all targets, metrics, and statistics"""
+     logger.info("Scanning dimensions...")
+     rank_files = defaultdict(list)
+
+     # Collect files for each rank
+     for rank, dir_path in data_dirs.items():
+         files = os.listdir(dir_path)
+         for file in files:
+             metric_name, _, _ = get_info_from_filename(
+                 file, metric_list=data_type_list)
+             if not metric_name:
+                 continue
+             rank_files[rank].append(os.path.join(dir_path, file))
+
+     # Parallel pre-scan
+     with ProcessPoolExecutor(max_workers=workers) as executor:
+         futures = {
+             executor.submit(_pre_scan_single_rank, rank, files): rank
+             for rank, files in rank_files.items()
+         }
+
+         results = []
+         with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar:
+             for future in as_completed(futures):
+                 rank = futures[future]
+                 try:
+                     result = future.result()
+                     results.append(result)
+                 except Exception as e:
+                     logger.error(
+                         f"Error pre-scanning rank {rank}: {str(e)}")
+                 pbar.update(1)
+
+     # Aggregate results
+     targets = OrderedDict()
+     metrics = set()
+     min_step = None
+     max_step = 0
+     max_rank = 0
+     metric_stats = defaultdict(set)
+
+     for rank_result in results:
+         max_rank = max(max_rank, rank_result['max_rank'])
+         metrics.update(rank_result['metrics'])
+         min_step = min(
+             min_step if min_step is not None else rank_result['min_step'],
+             rank_result['min_step']
+         )
+         max_step = max(max_step, rank_result['max_step'])
+
+         for metric, stats in rank_result['metric_stats'].items():
+             metric_stats[metric].update(stats)
+
+         targets = update_ordered_dict(targets, rank_result['targets'])
+
+     monitor_db.insert_dimensions(
+         targets, metrics, metric_stats, min_step=min_step, max_step=max_step)
+     monitor_db.update_global_stats(
+         max_rank=max_rank, min_step=min_step, max_step=max_step)
+     return rank_files
+
+
+ def process_single_rank(
+         task: Tuple[int, List[str]],
+         metric_id_dict: Dict[str, Tuple[int, List[str]]],
+         target_dict: Dict[Tuple[str, int, int], int],
+         step_partition_size: int,
+         db_path: str
+ ) -> int:
+     """Process data import for a single rank"""
+     rank, files = task
+     db = MonitorDB(db_path, step_partition_size=step_partition_size)
+     total_inserted = 0
+     table_batches = defaultdict(list)
+
+     for file in files:
+         filename = os.path.basename(file)
+         metric_name, _, _ = get_info_from_filename(filename)
+         if not metric_name:
+             continue
+         metric_info = metric_id_dict.get(metric_name)
+         if not metric_info:
+             continue
+
+         metric_id, stats = metric_info
+
+         for row_id, row in read_csv(file).iterrows():
+             try:
+                 # Parse row data
+                 name = row.get(MonitorConst.HEADER_NAME)
+                 vpp_stage = int(row['vpp_stage'])
+                 micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE))
+                 target_id = target_dict.get((name, vpp_stage, micro_step))
+                 if not target_id:
+                     continue
+
+                 step = int(row['step'])
+                 table_name, _, _ = db.get_metric_table_name(metric_id, step)
+                 # Prepare row data
+                 row_data = [rank, step, target_id]
+                 row_data.extend(
+                     float(row[stat]) if stat in row else None
+                     for stat in stats
+                 )
+             except (ValueError, KeyError) as e:
+                 logger.error(
+                     f"CSV conversion failed | file={file}:{row_id + 2} | error={str(e)}")
+                 continue
+
+             table_batches[table_name].append(tuple(row_data))
+             # Batch insert when threshold reached
+             if len(table_batches[table_name]) >= MonitorConst.BATCH_SIZE:
+                 inserted = db.insert_rows(
+                     table_name, table_batches[table_name])
+                 if inserted is not None:
+                     total_inserted += inserted
+                 table_batches[table_name] = []
+
+     # Insert remaining data
+     for table_name, batch in table_batches.items():
+         if batch:
+             inserted = db.insert_rows(table_name, batch)
+             if inserted is not None:
+                 total_inserted += inserted
+
+     logger.info(f"Rank {rank} inserted {total_inserted} rows")
+     return total_inserted
+
+
+ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool:
+     """Main method to import data into database"""
+     # 1. Pre-scan to get rank tasks
+     monitor_db.init_schema()
+     rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers)
+     if not rank_tasks:
+         logger.error("No valid data files found during pre-scan")
+         return False
+
+     # 2. Get metric and target mappings
+     try:
+         metric_id_dict = monitor_db.get_metric_mapping()
+         target_dict = monitor_db.get_target_mapping()
+     except Exception as e:
+         logger.error(f"Failed to get database mappings: {str(e)}")
+         return False
+
+     # 3. Process data for each rank in parallel
+     total_files = sum(len(files) for files in rank_tasks.values())
+     logger.info(f"Starting data import for {len(rank_tasks)} ranks, "
+                 f"{total_files} files...")
+     all_succeeded = True
+     with ProcessPoolExecutor(max_workers=workers) as executor:
+         futures = {
+             executor.submit(
+                 process_single_rank,
+                 (rank, files),
+                 metric_id_dict,
+                 target_dict,
+                 monitor_db.step_partition_size,
+                 monitor_db.db_path): rank
+             for rank, files in rank_tasks.items()
+         }
+
+         with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar:
+             for future in pbar:
+                 rank = futures[future]
+                 try:
+                     inserted = future.result()
+                     pbar.set_postfix_str(
+                         f"Rank {rank}: inserted {inserted} rows")
+                 except Exception as e:
+                     logger.error(
+                         f"Failed to process rank {rank}: {str(e)}")
+                     all_succeeded = False
+     return all_succeeded
+
+
+ def csv2db(config: CSV2DBConfig) -> None:
+     """Main function to convert CSV files to database"""
+     validate_process_num(config.process_num)
+     validate_step_partition(config.step_partition)
+     validate_data_type_list(config.data_type_list)
+
+     target_output_dirs = get_target_output_dir(
+         config.monitor_path, config.time_start, config.time_end)
+
+     if config.output_dirpath is None:
+         local_tz = pytz.timezone("Asia/Shanghai")
+         cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
+         config.output_dirpath = os.path.join(
+             config.monitor_path, f"{cur_time}-csv2db")
+
+     create_directory(config.output_dirpath)
+     db_path = os.path.join(config.output_dirpath, "monitor_metrics.db")
+
+     if os.path.exists(db_path):
+         logger.warning(f"Existing database {db_path} will be removed")
+         remove_path(db_path)
+
+     db = MonitorDB(db_path, step_partition_size=config.step_partition)
+
+     result = import_data(
+         db,
+         target_output_dirs,
+         config.data_type_list if config.data_type_list else all_data_type_list,
+         workers=config.process_num
+     )
+     recursive_chmod(config.output_dirpath)
+     if result:
+         logger.info(
+             f"Data import completed. Output saved to: {config.output_dirpath}")
+     else:
+         logger.warning(
+             f"Data import may be incomplete. Output directory: {config.output_dirpath} "
+             f"(some records might have failed)"
+         )
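
Taken together, the new csv2db.py runs in two phases: a parallel pre-scan that registers every target, metric, and step range in the dimension tables, then a parallel per-rank import into partitioned metric tables. A minimal sketch of driving it programmatically follows; only CSV2DBConfig and csv2db come from the diff above, the argument values are illustrative, and the accepted time-window format is whatever get_target_output_dir parses from the monitor output directory names:

# Hypothetical driver for the new converter; paths and values are made up.
from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db

config = CSV2DBConfig(
    monitor_path="./monitor_output",        # tree of per-rank CSV dumps
    process_num=4,                          # ranks scanned/imported in parallel
    data_type_list=["actv", "actv_grad"],   # subset of all_data_type_list
    step_partition=500,                     # steps per metric table partition
)
csv2db(config)
# With output_dirpath unset, results land in
# <monitor_path>/<timestamp>-csv2db/monitor_metrics.db
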
msprobe/core/monitor/db_utils.py
@@ -0,0 +1,278 @@
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from collections import OrderedDict
+ from collections.abc import Iterable
+ from typing import Dict, List, Optional, Set, Tuple
+
+ from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.db_manager import DBManager
+
+
+ def update_ordered_dict(main_dict: OrderedDict, new_list: List) -> OrderedDict:
+     """Update ordered dictionary with new items"""
+     for item in new_list:
+         if item not in main_dict:
+             main_dict[item] = None
+     return main_dict
+
+
+ def get_ordered_stats(stats: Iterable) -> List[str]:
+     """Get statistics in predefined order"""
+     if not isinstance(stats, Iterable):
+         return []
+     return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats]
+
+
+ class MonitorSql:
+     """SQL definitions for the monitoring database tables"""
+
+     @staticmethod
+     def create_monitoring_targets_table():
+         """Monitoring targets table"""
+         return """
+         CREATE TABLE IF NOT EXISTS monitoring_targets (
+             target_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             target_name TEXT NOT NULL,
+             vpp_stage INTEGER NOT NULL,
+             micro_step INTEGER NOT NULL DEFAULT 0,
+             UNIQUE(target_name, vpp_stage, micro_step)
+         )"""
+
+     @staticmethod
+     def create_monitoring_metrics_table():
+         """Monitoring metrics table"""
+         return """
+         CREATE TABLE IF NOT EXISTS monitoring_metrics (
+             metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             metric_name TEXT UNIQUE NOT NULL
+         )"""
+
+     @staticmethod
+     def get_metric_mapping_sql():
+         return """
+         SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats
+         FROM monitoring_metrics m
+         LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id
+         GROUP BY m.metric_id
+         """
+
+     @staticmethod
+     def create_metric_stats_table():
+         """Metric statistics table"""
+         return """
+         CREATE TABLE IF NOT EXISTS metric_stats (
+             metric_id INTEGER NOT NULL,
+             stat_name TEXT NOT NULL,
+             PRIMARY KEY (metric_id, stat_name),
+             FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id)
+         ) WITHOUT ROWID"""
+
+     @staticmethod
+     def create_global_stat_table():
+         return """
+         CREATE TABLE IF NOT EXISTS global_stats (
+             stat_name TEXT PRIMARY KEY,
+             stat_value INTEGER NOT NULL
+         ) WITHOUT ROWID"""
+
+     @classmethod
+     def get_table_definition(cls, table_name=""):
+         """
+         Get the CREATE TABLE SQL for a table.
+         :param table_name: table name; if empty, return the SQL for all tables
+         :return: CREATE TABLE statement(s)
+         :raises ValueError: if the table name is not supported
+         """
+         table_creators = {
+             "monitoring_targets": cls.create_monitoring_targets_table,
+             "monitoring_metrics": cls.create_monitoring_metrics_table,
+             "metric_stats": cls.create_metric_stats_table,
+             "global_stats": cls.create_global_stat_table,
+         }
+         if not table_name:
+             return [creator() for creator in table_creators.values()]
+         if table_name not in table_creators:
+             raise ValueError(f"Unsupported table name: {table_name}")
+         return table_creators[table_name]()
+
+     @classmethod
+     def get_metric_table_definition(cls, table_name, stats, partition=None):
+         stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats]
+         if partition and len(partition) == 2:
+             partition_start_step, partition_end_step = partition
+             step_column = (f"step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} "
+                            f"AND {partition_end_step}),")
+         else:
+             step_column = "step INTEGER NOT NULL,"
+         create_sql = f"""
+         CREATE TABLE {table_name} (
+             rank INTEGER NOT NULL,
+             {step_column}
+             target_id INTEGER NOT NULL,
+             {', '.join(stat_columns)},
+             PRIMARY KEY (rank, step, target_id),
+             FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id)
+         ) WITHOUT ROWID
+         """
+         return create_sql
+
+
+ class MonitorDB:
+     """Main class for monitoring database operations"""
+
+     def __init__(self, db_path: str, step_partition_size: int = 500):
+         self.db_path = db_path
+         self.db_manager = DBManager(db_path)
+         self.step_partition_size = step_partition_size
+
+     def get_metric_table_name(self, metric_id: int, step: int) -> Tuple[str, int, int]:
+         """Generate metric table name and its partition bounds"""
+         step_start = (
+             step // self.step_partition_size) * self.step_partition_size
+         step_end = step_start + self.step_partition_size - 1
+         return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end
+
+     def init_schema(self) -> None:
+         """Initialize database schema"""
+         self.db_manager.execute_multi_sql(MonitorSql.get_table_definition())
+
+         # Insert initial global stats
+         global_stats = [
+             ('max_rank', 0),
+             ('min_step', 0),
+             ('max_step', 0),
+             ('step_partition_size', self.step_partition_size)
+         ]
+         self.db_manager.insert_data("global_stats", global_stats)
+
+     def insert_dimensions(
+             self,
+             targets: OrderedDict,
+             metrics: Set[str],
+             metric_stats: Dict[str, Set[str]],
+             min_step: Optional[int] = None,
+             max_step: Optional[int] = None,
+     ) -> None:
+         """Insert dimension data into database"""
+         # Insert targets
+         self.db_manager.insert_data(
+             "monitoring_targets",
+             [(name, vpp_stage, micro_step)
+              for (name, vpp_stage, micro_step) in targets],
+             key_list=["target_name", "vpp_stage", "micro_step"]
+         )
+
+         # Insert metrics
+         self.db_manager.insert_data(
+             "monitoring_metrics",
+             [(metric,) for metric in metrics],
+             key_list=["metric_name"]
+         )
+
+         # Insert metric-stat relationships
+         for metric, stats in metric_stats.items():
+             metric_id = self._get_metric_id(metric)
+             ordered_stats = get_ordered_stats(stats)
+
+             self.db_manager.insert_data(
+                 "metric_stats",
+                 [(metric_id, stat) for stat in ordered_stats],
+                 key_list=["metric_id", "stat_name"]
+             )
+
+             # Create this metric's table for each step partition
+             if min_step is not None and max_step is not None:
+                 first_partition = min_step // self.step_partition_size
+                 last_partition = max_step // self.step_partition_size
+
+                 for partition in range(first_partition, last_partition + 1):
+                     step_start = partition * self.step_partition_size
+                     self.create_metric_table(
+                         metric_id, step_start, ordered_stats)
+
+     def insert_rows(self, table_name, rows):
+         if not self.db_manager.table_exists(table_name):
+             raise RuntimeError(f"{table_name} does not exist in {self.db_path}")
+         inserted = self.db_manager.insert_data(table_name, rows)
+         inserted = 0 if inserted is None else inserted
+         return inserted
+
+     def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str:
+         """Create metric table for a specific partition"""
+         table_name, partition_start_step, partition_end_step = self.get_metric_table_name(
+             metric_id,
+             step
+         )
+         if self.db_manager.table_exists(table_name):
+             return table_name
+
+         create_sql = MonitorSql.get_metric_table_definition(
+             table_name, stats, partition=(
+                 partition_start_step, partition_end_step)
+         )
+         self.db_manager.execute_sql(create_sql)
+         return table_name
+
+     def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None,
+                             max_step: int = None) -> None:
+         """Update global statistics"""
+         updates = [
+             ("max_rank", max_rank),
+             ("min_step", min_step),
+             ("max_step", max_step)
+         ]
+         for stat_name, value in updates:
+             if value is None:
+                 continue
+             self.db_manager.update_data(
+                 table_name="global_stats",
+                 updates={"stat_value": value},
+                 where={"stat_name": stat_name}
+             )
+
+     def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]:
+         """Get metric name to ID mapping with statistics"""
+         results = self.db_manager.execute_sql(
+             MonitorSql.get_metric_mapping_sql()
+         )
+
+         return {
+             row["metric_name"]: (
+                 row["metric_id"],
+                 get_ordered_stats(row["stats"].split(",")) if row["stats"] else []
+             ) for row in results
+         }
+
+     def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]:
+         """Get target mapping dictionary"""
+         results = self.db_manager.select_data(
+             table_name="monitoring_targets",
+             columns=["target_id", "target_name", "vpp_stage", "micro_step"]
+         )
+         if not results:
+             return {}
+         return {
+             (row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"]
+             for row in results
+         }
+
+     def _get_metric_id(self, metric_name: str) -> Optional[int]:
+         """Get metric ID by name"""
+         result = self.db_manager.select_data(
+             table_name="monitoring_metrics",
+             columns=["metric_id"],
+             where={"metric_name": metric_name}
+         )
+         return result[0]["metric_id"] if result else None
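
The step-partitioning arithmetic in get_metric_table_name is worth restating in isolation: each metric gets one SQLite table per step_partition_size-wide step bucket, with the bucket bounds baked into both the table name and its CHECK constraint. A standalone sketch of the same computation (not imported from the package):

# Re-statement of MonitorDB.get_metric_table_name's bucketing, for clarity.
def metric_table_name(metric_id: int, step: int, partition_size: int = 500) -> str:
    step_start = (step // partition_size) * partition_size
    step_end = step_start + partition_size - 1
    return f"metric_{metric_id}_step_{step_start}_{step_end}"

assert metric_table_name(3, 0) == "metric_3_step_0_499"
assert metric_table_name(3, 499) == "metric_3_step_0_499"
assert metric_table_name(3, 500) == "metric_3_step_500_999"
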
msprobe/core/monitor/utils.py
@@ -96,8 +96,33 @@ def validate_targets(targets):
              raise TypeError('key of targets should be module_name[str] in config.json')
          if not isinstance(field, dict):
              raise TypeError('values of targets should be cared fields e.g. {"input": "tensor"} in config.json')
+ 
  
- 
+ def validate_l2_targets(targets):
+     if not isinstance(targets, dict):
+         raise TypeError('l2_targets in config.json should be a dict')
+     for hook_name, target_list in targets.items():
+         if hook_name not in MonitorConst.L2_HOOKS:
+             raise TypeError(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}, got {hook_name}')
+         if not isinstance(target_list, list):
+             raise TypeError('values of l2_targets should be a list in config.json')
+         for item in target_list:
+             if not isinstance(item, str):
+                 raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json')
+ 
+ 
+ def validate_recording_l2_features(recording_l2_features):
+     if not isinstance(recording_l2_features, bool):
+         raise TypeError("recording_l2_features should be a bool")
+ 
+ 
+ def validate_sa_order(sa_order):
+     if isinstance(sa_order, str):
+         sa_order = sa_order.replace(' ', '')
+     if sa_order not in MonitorConst.SA_ORDERS:
+         raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}')
+ 
+ 
  def validate_print_struct(print_struct):
      if not isinstance(print_struct, bool):
          raise TypeError("print_struct should be a bool")
@@ -216,6 +241,15 @@ def validate_config(config):
      targets = config.get("targets", {})
      validate_targets(targets)
  
+     l2_targets = config.get("l2_targets", {})
+     validate_l2_targets(l2_targets)
+ 
+     recording_l2_features = config.get("recording_l2_features", False)
+     validate_recording_l2_features(recording_l2_features)
+ 
+     sa_order = config.get("sa_order", "s,b,h,d")
+     validate_sa_order(sa_order)
+ 
  
      print_struct = config.get('print_struct', False)
      validate_print_struct(print_struct)
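
The three new validators correspond to three new keys in the monitor's config.json. A hedged example of a config fragment they would accept (the legal hook names and orderings live in MonitorConst.L2_HOOKS and MonitorConst.SA_ORDERS, which this diff does not show, so the literals below are assumptions):

# Illustrative monitor config fragment exercising the new keys;
# "linear" is an assumed member of MonitorConst.L2_HOOKS.
config = {
    "targets": {"module.layers.0": {"input": "tensor"}},
    "l2_targets": {"linear": ["module.layers.0.mlp"]},
    "recording_l2_features": True,
    "sa_order": "s,b,h,d",  # default applied by validate_config
}
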