mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/README.md
CHANGED
|
@@ -35,17 +35,17 @@ export MSPROBE_LOG_LEVEL={x}
|
|
|
35
35
|
|
|
36
36
|
## 环境和依赖
|
|
37
37
|
|
|
38
|
-
- 硬件环境请参见《[昇腾产品形态说明](https://
|
|
39
|
-
- 软件环境请参见《[CANN 软件安装指南](https://
|
|
38
|
+
- 硬件环境请参见《[昇腾产品形态说明](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/quickstart/quickstart/quickstart_18_0002.html)》。
|
|
39
|
+
- 软件环境请参见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/softwareinst/instg/instg_0000.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。
|
|
40
40
|
|
|
41
41
|
以上环境依赖请根据实际环境选择适配的版本。
|
|
42
42
|
|
|
43
43
|
## 版本配套说明
|
|
44
44
|
|
|
45
|
-
- msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://
|
|
45
|
+
- msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitcode.com/Ascend/pytorch)》。
|
|
46
46
|
- msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。
|
|
47
47
|
- msprobe支持MSAdapter 2.1.0。
|
|
48
|
-
- msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://
|
|
48
|
+
- msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://www.hiascend.com/hardware/firmware-drivers/community?product=2&model=28&cann=8.0.RC3.alpha003&driver=1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
|
|
49
49
|
|
|
50
50
|
|
|
51
51
|
## 🚨 工具限制与注意事项
|
|
@@ -84,7 +84,7 @@ msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API
|
|
|
84
84
|
|
|
85
85
|
精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
|
|
86
86
|
|
|
87
|
-
PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)
|
|
87
|
+
PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)
|
|
88
88
|
|
|
89
89
|
MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
|
|
90
90
|
|
|
@@ -165,7 +165,7 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
|
|
|
165
165
|
|
|
166
166
|
训练前或精度比对前,对比两个环境下可能影响训练精度的配置差异。
|
|
167
167
|
|
|
168
|
-
[
|
|
168
|
+
[训练前配置检查](./docs/31.config_check.md)
|
|
169
169
|
|
|
170
170
|
训练过程中或结束后,比较两个不同的checkpoint,评估模型相似度。
|
|
171
171
|
|
msprobe/core/common/const.py
CHANGED
|
@@ -24,6 +24,8 @@ class Const:
|
|
|
24
24
|
Class for const
|
|
25
25
|
"""
|
|
26
26
|
TOOL_NAME = "msprobe"
|
|
27
|
+
MD5_INDEX = "md5_index"
|
|
28
|
+
MD5 = "md5"
|
|
27
29
|
|
|
28
30
|
ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
|
|
29
31
|
SEP = "."
|
|
@@ -52,9 +54,9 @@ class Const:
|
|
|
52
54
|
SIX_SEGMENT = 6
|
|
53
55
|
SEVEN_SEGMENT = 7
|
|
54
56
|
|
|
55
|
-
MAX_DEPTH =
|
|
57
|
+
MAX_DEPTH = 400
|
|
56
58
|
CPU_QUARTER = 4
|
|
57
|
-
DUMP_MAX_DEPTH =
|
|
59
|
+
DUMP_MAX_DEPTH = 400
|
|
58
60
|
|
|
59
61
|
EXTERN_INPUT_LIST_MAX_LEN = 100
|
|
60
62
|
MAX_PROCESS_NUM = 128
|
|
@@ -72,6 +74,7 @@ class Const:
|
|
|
72
74
|
ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
|
|
73
75
|
SUMMARY = "summary"
|
|
74
76
|
MD5 = "md5"
|
|
77
|
+
HASH = "hash"
|
|
75
78
|
VALUE = "value"
|
|
76
79
|
SUMMARY_MODE = ["statistics", "md5"]
|
|
77
80
|
|
|
@@ -113,9 +116,13 @@ class Const:
|
|
|
113
116
|
RUN_UT = "run_ut"
|
|
114
117
|
GRAD_PROBE = "grad_probe"
|
|
115
118
|
STRUCTURE = "structure"
|
|
116
|
-
|
|
119
|
+
EXCEPTION_DUMP = "exception_dump"
|
|
120
|
+
DUMP_PRECISION_HIGH = "high"
|
|
121
|
+
DUMP_PRECISION_LOW = "low"
|
|
122
|
+
TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE, EXCEPTION_DUMP]
|
|
117
123
|
DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
|
|
118
124
|
DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
|
|
125
|
+
DUMP_PRECISION_LIST = [DUMP_PRECISION_LOW, DUMP_PRECISION_HIGH]
|
|
119
126
|
LEVEL_L0 = "L0"
|
|
120
127
|
LEVEL_L1 = "L1"
|
|
121
128
|
LEVEL_L2 = "L2"
|
|
@@ -237,7 +244,11 @@ class Const:
|
|
|
237
244
|
MEAN = 'Mean'
|
|
238
245
|
NORM = 'Norm'
|
|
239
246
|
DATA_NAME = 'data_name'
|
|
247
|
+
STATE = 'state'
|
|
248
|
+
REQ_GRAD = 'requires_grad'
|
|
249
|
+
API_ORIGIN_NAME = 'api_origin_name'
|
|
240
250
|
TENSOR_STAT_INDEX = 'tensor_stat_index'
|
|
251
|
+
SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM]
|
|
241
252
|
|
|
242
253
|
CODE_STACK = 'Code Stack'
|
|
243
254
|
OP_NAME = 'Op Name'
|
|
@@ -260,8 +271,15 @@ class Const:
|
|
|
260
271
|
|
|
261
272
|
TENSOR_STAT_LEN = 2
|
|
262
273
|
|
|
274
|
+
TENSOR_TYPE = "torch.Tensor"
|
|
275
|
+
DTENSOR_TYPE = "torch.distributed.tensor.DTensor"
|
|
276
|
+
FAKE_TENSOR_TYPE = "torch._subclasses.fake_tensor.FakeTensor"
|
|
277
|
+
AC_TENSOR_TYPE = "torch.distributed._functional_collectives.AsyncCollectiveTensor"
|
|
278
|
+
|
|
263
279
|
SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
|
|
264
280
|
|
|
281
|
+
API_ATTR_LIST = ["__name__", "default"]
|
|
282
|
+
|
|
265
283
|
PT_API_TYPE_FUNCTIONAL = "functional"
|
|
266
284
|
PT_API_TYPE_TENSOR = "tensor"
|
|
267
285
|
PT_API_TYPE_TORCH = "torch"
|
|
@@ -355,22 +373,22 @@ class Const:
|
|
|
355
373
|
}
|
|
356
374
|
|
|
357
375
|
def _fused_adamw_(
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
376
|
+
self,
|
|
377
|
+
grads,
|
|
378
|
+
exp_avgs,
|
|
379
|
+
exp_avg_sqs,
|
|
380
|
+
max_exp_avg_sqs,
|
|
381
|
+
state_steps,
|
|
382
|
+
*,
|
|
383
|
+
lr,
|
|
384
|
+
beta1,
|
|
385
|
+
beta2,
|
|
386
|
+
weight_decay,
|
|
387
|
+
eps,
|
|
388
|
+
amsgrad,
|
|
389
|
+
maximize,
|
|
390
|
+
grad_scale=None,
|
|
391
|
+
found_inf=None
|
|
374
392
|
):
|
|
375
393
|
pass
|
|
376
394
|
|
|
@@ -382,6 +400,13 @@ class Const:
|
|
|
382
400
|
MATCH_MODE_NAME = "pure name"
|
|
383
401
|
MATCH_MODE_MAPPING = "mapping"
|
|
384
402
|
MATCH_MODE_SIMILARITY = "similarity"
|
|
403
|
+
CONFIG_CHECK_PASS = "pass"
|
|
404
|
+
CONFIG_CHECK_WARNING = "warning"
|
|
405
|
+
CONFIG_CHECK_ERROR = "error"
|
|
406
|
+
|
|
407
|
+
MIX_DUMP_NAMES = {'graph', 'pynative'}
|
|
408
|
+
|
|
409
|
+
MEGATRON_MICRO_STEP_NUMBER = 'megatron_micro_step_number'
|
|
385
410
|
|
|
386
411
|
|
|
387
412
|
class CompareConst:
|
|
@@ -397,10 +422,14 @@ class CompareConst:
|
|
|
397
422
|
BENCH_DTYPE = "Bench Dtype"
|
|
398
423
|
NPU_SHAPE = "NPU Tensor Shape"
|
|
399
424
|
BENCH_SHAPE = "Bench Tensor Shape"
|
|
425
|
+
NPU_CSV_FILE = "NPU CSV File"
|
|
426
|
+
BENCH_CSV_FILE = "Bench CSV File"
|
|
400
427
|
NPU_MAX = "NPU max"
|
|
401
428
|
NPU_MIN = "NPU min"
|
|
402
429
|
NPU_MEAN = "NPU mean"
|
|
403
430
|
NPU_NORM = "NPU l2norm"
|
|
431
|
+
NPU_P2POP_PEER = "NPU P2POp peer"
|
|
432
|
+
|
|
404
433
|
BENCH_MAX = "Bench max"
|
|
405
434
|
BENCH_MIN = "Bench min"
|
|
406
435
|
BENCH_MEAN = "Bench mean"
|
|
@@ -416,6 +445,9 @@ class CompareConst:
|
|
|
416
445
|
MIN_RELATIVE_ERR = "MinRelativeErr"
|
|
417
446
|
MEAN_RELATIVE_ERR = "MeanRelativeErr"
|
|
418
447
|
NORM_RELATIVE_ERR = "NormRelativeErr"
|
|
448
|
+
REQ_GRAD_CONSIST = "Requires_grad Consistent"
|
|
449
|
+
NPU_REQ_GRAD = "NPU Requires_grad"
|
|
450
|
+
BENCH_REQ_GRAD = "Bench Requires_grad"
|
|
419
451
|
ACCURACY = "Accuracy Reached or Not"
|
|
420
452
|
STACK = "NPU_Stack_Info"
|
|
421
453
|
DATA_NAME = "Data_name"
|
|
@@ -437,7 +469,7 @@ class CompareConst:
|
|
|
437
469
|
SUMMARY = "summary"
|
|
438
470
|
COMPARE_RESULT = "compare_result"
|
|
439
471
|
COMPARE_MESSAGE = "compare_message"
|
|
440
|
-
MAX_EXCEL_LENGTH =
|
|
472
|
+
MAX_EXCEL_LENGTH = 1048500
|
|
441
473
|
YES = "Yes"
|
|
442
474
|
NO = "No"
|
|
443
475
|
STATISTICS_INDICATOR_NUM = 4
|
|
@@ -485,21 +517,21 @@ class CompareConst:
|
|
|
485
517
|
|
|
486
518
|
ULP_ERR_STATUS = "ulp_err_status"
|
|
487
519
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
]
|
|
520
|
+
ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR,
|
|
521
|
+
ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
|
|
522
|
+
SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
|
|
523
|
+
MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
|
|
524
|
+
MD5_COMPARE_INDEX = [RESULT]
|
|
493
525
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR,
|
|
497
|
-
NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE
|
|
498
|
-
]
|
|
526
|
+
BASIC_INFO = [NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_REQ_GRAD, BENCH_REQ_GRAD]
|
|
527
|
+
SUMMARY_INFO = [NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM]
|
|
499
528
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
529
|
+
COMPARE_RESULT_HEADER = BASIC_INFO + ALL_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, ACCURACY, ERROR_MESSAGE]
|
|
530
|
+
|
|
531
|
+
SUMMARY_COMPARE_RESULT_HEADER = BASIC_INFO + SUMMARY_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, RESULT,
|
|
532
|
+
ERROR_MESSAGE]
|
|
533
|
+
|
|
534
|
+
MD5_COMPARE_RESULT_HEADER = BASIC_INFO + [NPU_MD5, BENCH_MD5, REQ_GRAD_CONSIST] + MD5_COMPARE_INDEX
|
|
503
535
|
|
|
504
536
|
COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK]
|
|
505
537
|
|
|
@@ -513,11 +545,6 @@ class CompareConst:
|
|
|
513
545
|
Const.MD5: MD5_COMPARE_RESULT_HEADER
|
|
514
546
|
}
|
|
515
547
|
|
|
516
|
-
ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
|
|
517
|
-
FIVE_THOUSANDTHS_ERR_RATIO]
|
|
518
|
-
SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
|
|
519
|
-
MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
|
|
520
|
-
|
|
521
548
|
# dtype match
|
|
522
549
|
|
|
523
550
|
DTYPE_MATCH_GROUPS = [
|
|
@@ -554,6 +581,8 @@ class CompareConst:
|
|
|
554
581
|
ULP_FLOAT16_THRESHOLD = 1
|
|
555
582
|
|
|
556
583
|
# compare result data
|
|
584
|
+
NO_REAL_DATA = 'No real data'
|
|
585
|
+
API_UNMATCH = 'api unmatched'
|
|
557
586
|
READ_NONE = 'No data'
|
|
558
587
|
NONE = 'None'
|
|
559
588
|
SHAPE_UNMATCH = 'shape unmatched'
|
|
@@ -622,6 +651,9 @@ class CompareConst:
|
|
|
622
651
|
MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
|
|
623
652
|
MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
|
|
624
653
|
}
|
|
654
|
+
MS_GRAPH_CSV = {
|
|
655
|
+
NPU_CSV_FILE: None, BENCH_CSV_FILE: None
|
|
656
|
+
}
|
|
625
657
|
|
|
626
658
|
API_MAPPING_KEYS_TO_COMPARE = [
|
|
627
659
|
('ms_args', 'pt_args'),
|
|
@@ -641,9 +673,11 @@ class CompareConst:
|
|
|
641
673
|
|
|
642
674
|
OP_NAME_X = 'op_name_x'
|
|
643
675
|
MATCH_RESULT_COLUMNS = [
|
|
644
|
-
OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', '
|
|
676
|
+
OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'state_x', 'api_origin_name_x',
|
|
677
|
+
'requires_grad_x', 'data_name_x',
|
|
645
678
|
CMP_KEY, CMP_SHAPE,
|
|
646
|
-
'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', '
|
|
679
|
+
'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'state_y', 'api_origin_name_y',
|
|
680
|
+
'requires_grad_y', 'data_name_y'
|
|
647
681
|
]
|
|
648
682
|
|
|
649
683
|
INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
|
|
@@ -674,6 +708,8 @@ class FileCheckConst:
|
|
|
674
708
|
IR_SUFFIX = ".ir"
|
|
675
709
|
ZIP_SUFFIX = ".zip"
|
|
676
710
|
SHELL_SUFFIX = ".sh"
|
|
711
|
+
LOG_SUFFIX = ".log"
|
|
712
|
+
DB_SUFFIX = '.db'
|
|
677
713
|
MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
678
714
|
MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
679
715
|
MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
@@ -686,6 +722,8 @@ class FileCheckConst:
|
|
|
686
722
|
MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
687
723
|
MAX_FILE_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
688
724
|
COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
|
|
725
|
+
MAX_LOG_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
726
|
+
MAX_DB_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
689
727
|
DIR = "dir"
|
|
690
728
|
FILE = "file"
|
|
691
729
|
DATA_DIR_AUTHORITY = 0o750
|
|
@@ -699,7 +737,9 @@ class FileCheckConst:
|
|
|
699
737
|
XLSX_SUFFIX: MAX_XLSX_SIZE,
|
|
700
738
|
YAML_SUFFIX: MAX_YAML_SIZE,
|
|
701
739
|
IR_SUFFIX: MAX_IR_SIZE,
|
|
702
|
-
ZIP_SUFFIX: MAX_ZIP_SIZE
|
|
740
|
+
ZIP_SUFFIX: MAX_ZIP_SIZE,
|
|
741
|
+
LOG_SUFFIX: MAX_LOG_SIZE,
|
|
742
|
+
DB_SUFFIX: MAX_DB_SIZE
|
|
703
743
|
}
|
|
704
744
|
CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
|
|
705
745
|
|
|
@@ -758,6 +798,11 @@ class MonitorConst:
|
|
|
758
798
|
DEFAULT_STEP_INTERVAL = 1
|
|
759
799
|
|
|
760
800
|
OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"]
|
|
801
|
+
OP_MONVIS_SUPPORTED = [
|
|
802
|
+
"norm", "min", "max", "zeros", "nans", "mean",
|
|
803
|
+
"entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian",
|
|
804
|
+
"proxy", "token_similarity"
|
|
805
|
+
]
|
|
761
806
|
MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
|
|
762
807
|
DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
|
|
763
808
|
DATABASE = "database"
|
|
@@ -770,6 +815,8 @@ class MonitorConst:
|
|
|
770
815
|
)
|
|
771
816
|
DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
|
|
772
817
|
RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
|
|
818
|
+
L2_HOOKS = ["linear_hook", "attention_hook"]
|
|
819
|
+
SA_ORDERS = ["s,b,h,d", "b,s,h,d"]
|
|
773
820
|
|
|
774
821
|
SLICE_SIZE = 20480
|
|
775
822
|
# used for name
|
|
@@ -781,6 +828,7 @@ class MonitorConst:
|
|
|
781
828
|
ACTV_OUT = "output"
|
|
782
829
|
ACTVGRAD_IN = "input_grad"
|
|
783
830
|
ACTVGRAD_OUT = "output_grad"
|
|
831
|
+
FSDP_FLAT_SEP = "_fsdp_wrapped_module."
|
|
784
832
|
# used for tasks
|
|
785
833
|
ACTV = "actv"
|
|
786
834
|
ACTVGRAD = "actv_grad"
|
|
@@ -820,3 +868,12 @@ class MonitorConst:
|
|
|
820
868
|
TRAIN_STAGE[key] = BACKWARD_STAGE
|
|
821
869
|
for key in OPTIMIZER_KEY:
|
|
822
870
|
TRAIN_STAGE[key] = OPTIMIZER_STAGE
|
|
871
|
+
|
|
872
|
+
# csv2db
|
|
873
|
+
DEFAULT_INT_VALUE = 0
|
|
874
|
+
MAX_PROCESS_NUM = 128
|
|
875
|
+
CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
|
|
876
|
+
BATCH_SIZE = 10000
|
|
877
|
+
MAX_PARTITION = 10_000_000
|
|
878
|
+
MIN_PARTITION = 10
|
|
879
|
+
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import re
|
|
16
|
+
import sqlite3
|
|
17
|
+
from typing import List, Tuple, Dict, Any
|
|
18
|
+
from functools import wraps
|
|
19
|
+
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
from msprobe.core.common.file_utils import check_path_before_create, change_mode
|
|
22
|
+
from msprobe.core.common.const import FileCheckConst
|
|
23
|
+
|
|
24
|
+
SAFE_SQL_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def check_identifier_safety(name):
    """Validate that *name* may be formatted into SQL text as an identifier.

    Table and column names cannot be bound as ``?`` parameters, so the
    DBManager methods interpolate them directly; this guard admits only names
    accepted in full by the module-level ``SAFE_SQL_PATTERN``.

    :param name: candidate table or column name.
    :raises ValueError: if *name* is not a string or does not fully match
        ``SAFE_SQL_PATTERN`` (potential SQL injection).
    """
    # fullmatch instead of match: with match(), the pattern's trailing ``$``
    # also matches just before a final newline, so "table\n" would be accepted.
    if not isinstance(name, str) or SAFE_SQL_PATTERN.fullmatch(name) is None:
        raise ValueError(f"Invalid SQL identifier: {name}, potential SQL injection risk!")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _db_operation(func):
|
|
34
|
+
"""数据库操作装饰器,自动管理连接"""
|
|
35
|
+
@wraps(func)
|
|
36
|
+
def wrapper(self, *args, **kwargs):
|
|
37
|
+
conn, curs = None, None
|
|
38
|
+
try:
|
|
39
|
+
conn, curs = self._get_connection()
|
|
40
|
+
result = func(self, conn, curs, *args, **kwargs)
|
|
41
|
+
return result # 显式返回正常结果
|
|
42
|
+
|
|
43
|
+
except sqlite3.Error as err:
|
|
44
|
+
logger.error(f"Database operation failed: {err}")
|
|
45
|
+
if conn:
|
|
46
|
+
conn.rollback()
|
|
47
|
+
return None # 显式返回错误情况下的None
|
|
48
|
+
|
|
49
|
+
finally:
|
|
50
|
+
self._release_connection(conn, curs)
|
|
51
|
+
return wrapper
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class DBManager:
|
|
55
|
+
"""
|
|
56
|
+
数据库管理类,封装常用数据库操作
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
DEFAULT_FETCH_SIZE = 10000
|
|
60
|
+
DEFAULT_INSERT_SIZE = 10000
|
|
61
|
+
MAX_ROW_COUNT = 100000000
|
|
62
|
+
|
|
63
|
+
def __init__(self, db_path: str):
|
|
64
|
+
"""
|
|
65
|
+
初始化DBManager
|
|
66
|
+
:param db_path: 数据库文件路径
|
|
67
|
+
:param table_config: 表配置对象
|
|
68
|
+
"""
|
|
69
|
+
self.db_path = db_path
|
|
70
|
+
|
|
71
|
+
@staticmethod
def _get_where_sql(where_list):
    """Build a parameterized WHERE clause from a column -> value mapping.

    Each entry becomes ``col = ?`` and the entries are joined with AND; the
    values are returned separately so they are bound as parameters rather
    than formatted into the SQL text.

    :param where_list: mapping of column name to required value; ``None`` or
        an empty mapping means "no filter".
    :return: ``(where_sql, params)`` -- ``("", ())`` when there is no filter,
        otherwise e.g. ``(" WHERE a = ? AND b = ?", (1, 2))``.
    :raises ValueError: if a column name fails ``check_identifier_safety``.
    """
    if not where_list:
        return "", tuple()

    where_clauses = []
    where_values = []
    for col, val in where_list.items():
        check_identifier_safety(col)
        where_clauses.append(f"{col} = ?")
        where_values.append(val)

    # Defensive: a truthy mapping that yields no items must not produce a
    # dangling " WHERE " (the previous layout left where_sql unbound here).
    if not where_clauses:
        return "", tuple()
    return " WHERE " + " AND ".join(where_clauses), tuple(where_values)
|
|
86
|
+
|
|
87
|
+
@_db_operation
def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
                table_name: str, data: List[Tuple], key_list: List[str] = None) -> int:
    """Insert rows in batches using ``INSERT OR IGNORE`` and commit.

    Called as ``insert_data(table_name, data, key_list=None)``; ``conn`` and
    ``curs`` are injected by ``_db_operation``.

    :param table_name: target table (validated as a safe identifier).
    :param data: rows to insert. The placeholder count is taken from
        ``data[0]``; a row of a different length makes sqlite raise, which the
        decorator turns into a ``None`` return.
    :param key_list: optional column names matching the row layout. When
        omitted (or empty), rows are inserted without a column list.
    :return: the sum of ``cursor.rowcount`` over all batches (rows skipped by
        ``OR IGNORE`` are not modifications); 0 for empty ``data``; ``None``
        if a ``sqlite3.Error`` occurs.
    :raises TypeError: if ``key_list`` is given but is not a list.
    :raises ValueError: for unsafe identifiers or a key_list/row length mismatch.
    """
    check_identifier_safety(table_name)

    if not data:
        return 0
    columns = len(data[0])
    if key_list:
        if not isinstance(key_list, list):
            raise TypeError(
                f"key_list must be a list, got {type(key_list)}"
            )
        if columns != len(key_list):
            raise ValueError(
                f"When inserting into table {table_name}, the length of key list ({len(key_list)}) "
                f"does not match the data ({columns}).")
        for key in key_list:
            check_identifier_safety(key)

    batch_size = self.DEFAULT_INSERT_SIZE
    placeholders = ", ".join(["?"] * columns)
    if key_list:
        keys = ", ".join(key_list)
        sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})"
    else:
        sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})"

    inserted_rows = 0
    for start in range(0, len(data), batch_size):
        curs.executemany(sql, data[start:start + batch_size])
        # For executemany, sqlite3 sums the modifications into rowcount.
        inserted_rows += curs.rowcount

    conn.commit()
    return inserted_rows
|
|
130
|
+
|
|
131
|
+
@_db_operation
def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
                table_name: str,
                columns: List[str] = None,
                where: dict = None) -> List[Dict]:
    """Select the given columns from a table, optionally filtered by equality.

    Called as ``select_data(table_name, columns, where=None)``; ``conn`` and
    ``curs`` are injected by ``_db_operation``.

    :param table_name: table to read (validated as a safe identifier).
    :param columns: non-empty list of column names (each validated).
    :param where: optional column -> value mapping, AND-ed and bound as parameters.
    :return: one dict per row (rows come from ``sqlite3.Row``), or ``None``
        if a ``sqlite3.Error`` occurs.
    :raises ValueError: if ``columns`` is empty or an identifier is unsafe.
    :raises TypeError: if ``columns`` is not a list of strings.
    """
    check_identifier_safety(table_name)

    if not columns:
        raise ValueError("columns parameter cannot be empty, specify columns to select (e.g. ['id', 'name'])")
    is_str_list = isinstance(columns, list) and all(isinstance(col, str) for col in columns)
    if not is_str_list:
        raise TypeError("columns must be a list of strings (e.g. ['id', 'name'])")

    for column_name in columns:
        check_identifier_safety(column_name)

    where_sql, where_params = self._get_where_sql(where)
    query = f"SELECT {', '.join(columns)} FROM {table_name}{where_sql}"
    curs.execute(query, where_params)

    return [dict(record) for record in curs.fetchall()]
|
|
160
|
+
|
|
161
|
+
@_db_operation
def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
                table_name: str, updates: Dict[str, Any],
                where: dict = None) -> int:
    """Run a parameterized ``UPDATE`` and commit it.

    Called as ``update_data(table_name, updates, where=None)``; ``conn`` and
    ``curs`` are injected by ``_db_operation``.

    :param table_name: table to update (validated as a safe identifier).
    :param updates: non-empty dict of column -> new value; values are bound as
        parameters, keys are validated identifiers.
    :param where: optional column -> value equality filter (AND-ed). When it is
        ``None`` or empty, EVERY row of the table is updated.
    :return: ``cursor.rowcount`` of the UPDATE, or ``None`` if a
        ``sqlite3.Error`` occurs.
    :raises ValueError: if ``updates`` is empty or an identifier is unsafe.
    :raises TypeError: if ``updates`` is not a dict.
    """
    check_identifier_safety(table_name)
    if not updates:
        raise ValueError("updates parameter cannot be empty, specify columns to update (e.g. {'name': 'xxx'})")
    if not isinstance(updates, dict):
        raise TypeError(f"updates must be a dictionary, got: {type(updates)}")
    for key in updates:
        check_identifier_safety(key)

    set_clause = ", ".join(f"{key} = ?" for key in updates)
    where_sql, where_params = self._get_where_sql(where)

    # SET values are bound first, then WHERE values, matching placeholder order.
    curs.execute(f"UPDATE {table_name} SET {set_clause}{where_sql}",
                 tuple(updates.values()) + where_params)
    conn.commit()
    return curs.rowcount
|
|
191
|
+
|
|
192
|
+
@_db_operation
def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
                sql: str, params: Tuple = None) -> List[Dict]:
    """
    Execute a custom SQL statement.

    Whether the statement produced a result set is decided from
    ``curs.description`` (set for any row-returning statement, even when it
    matches zero rows) instead of from the statement's leading keyword. Queries
    such as ``WITH ... SELECT``, statements preceded by an SQL comment,
    row-returning ``PRAGMA``s and ``... RETURNING`` clauses therefore return
    their rows instead of an empty list. The connection is committed in every
    case, so data-modifying statements (including ones with ``RETURNING``) are
    persisted.

    :param conn: open database connection (presumably supplied by ``_db_operation``)
    :param curs: cursor on ``conn``
    :param sql: SQL statement to execute
    :param params: values bound to the statement's placeholders; None means no parameters
    :return: result rows converted with ``dict(row)``; an empty list when the
             statement produced no result set
    """
    curs.execute(sql, params or ())
    if curs.description is not None:
        rows = [dict(row) for row in curs.fetchall()]
    else:
        rows = []
    conn.commit()
    return rows
|
|
206
|
+
|
|
207
|
+
def table_exists(self, table_name: str) -> bool:
    """
    Tell whether a table with the given name exists in the database.

    Looks the name up in ``sqlite_master`` (rows with ``type = 'table'``) through
    ``self.select_data``.

    :param table_name: name of the table to look for
    :return: True if at least one matching ``sqlite_master`` row was found
    """
    lookup_filter = {"type": "table", "name": table_name}
    matches = self.select_data(
        table_name="sqlite_master",
        columns=["name"],
        where=lookup_filter,
    )
    return bool(matches)
|
|
218
|
+
|
|
219
|
+
@_db_operation
def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
                      sql_commands: List[str]) -> List[List[Dict]]:
    """
    Execute several SQL statements in order on the same cursor, then commit once.

    Statements are executed without bound parameters. Rows are collected only
    for statements whose text, after stripping surrounding whitespace and
    upper-casing, starts with ``SELECT``.

    NOTE(review): if a statement raises, the commit below is not reached; what
    happens to earlier statements then depends on ``_db_operation`` (not visible
    here) - confirm.

    :param sql_commands: [sql1, sql2, ...]
    :return: one list of row dicts per SELECT statement, in execution order
    """
    select_results = []
    for statement in sql_commands:
        curs.execute(statement)
        is_select = statement.strip().upper().startswith("SELECT")
        if not is_select:
            continue
        select_results.append([dict(record) for record in curs.fetchall()])
    conn.commit()
    return select_results
|
|
234
|
+
|
|
235
|
+
def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
    """
    Open a connection to ``self.db_path`` and create a cursor on it.

    ``check_path_before_create`` is run on ``self.db_path`` first. The
    connection's row factory is ``sqlite3.Row``, so fetched rows can be turned
    into dicts with ``dict(row)``.

    :return: ``(connection, cursor)``
    :raises sqlite3.Error: if connecting or creating the cursor fails (logged first)
    """
    check_path_before_create(self.db_path)
    try:
        connection = sqlite3.connect(self.db_path)
        # Row objects expose column names, which callers rely on for dict(row).
        connection.row_factory = sqlite3.Row
        cursor = connection.cursor()
    except sqlite3.Error as err:
        logger.error(f"Database connection failed: {err}")
        raise
    return connection, cursor
|
|
246
|
+
|
|
247
|
+
def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None:
    """
    Close a cursor/connection pair and re-apply the database file permissions.

    The cursor and the connection are closed independently: a ``sqlite3.Error``
    raised while closing the cursor is logged and no longer prevents the
    connection from being closed (which previously leaked it). Either argument
    may be None. Afterwards ``change_mode`` applies
    ``FileCheckConst.DATA_FILE_AUTHORITY`` to ``self.db_path``.

    :param conn: connection to close, or None
    :param curs: cursor to close, or None
    """
    # Cursor first, then its connection; each close is attempted regardless of the other.
    for handle in (curs, conn):
        if handle is None:
            continue
        try:
            handle.close()
        except sqlite3.Error as err:
            logger.error(f"Failed to release database connection: {err}")
    change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -33,7 +33,7 @@ import pandas as pd
|
|
|
33
33
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
34
34
|
from msprobe.core.common.log import logger
|
|
35
35
|
from msprobe.core.common.exceptions import FileCheckException
|
|
36
|
-
from msprobe.core.common.const import FileCheckConst, CompareConst
|
|
36
|
+
from msprobe.core.common.const import FileCheckConst, CompareConst, Const
|
|
37
37
|
from msprobe.core.common.global_lock import global_lock, is_main_process
|
|
38
38
|
|
|
39
39
|
proc_lock = multiprocessing.Lock()
|
|
@@ -172,7 +172,7 @@ def check_path_exists(path):
|
|
|
172
172
|
if not os.path.exists(path):
|
|
173
173
|
logger.error('The file path %s does not exist.' % path)
|
|
174
174
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
|
|
175
|
-
|
|
175
|
+
|
|
176
176
|
|
|
177
177
|
def check_path_not_exists(path):
|
|
178
178
|
if os.path.exists(path):
|
|
@@ -259,8 +259,8 @@ def check_path_type(file_path, file_type):
|
|
|
259
259
|
def check_others_writable(directory):
    """
    Report whether a path is writable by its group or by other users.

    :param directory: path to inspect (passed to ``os.stat``)
    :return: True if the group-write or the other-write permission bit is set
    """
    mode = os.stat(directory).st_mode
    # Group-writable (S_IWGRP) or writable by other users (S_IWOTH).
    group_or_other_write = stat.S_IWGRP | stat.S_IWOTH
    return bool(mode & group_or_other_write)
|
|
266
266
|
|
|
@@ -319,7 +319,7 @@ def check_dirpath_before_read(path):
|
|
|
319
319
|
check_path_owner_consistent(dirpath)
|
|
320
320
|
except FileCheckException:
|
|
321
321
|
logger.warning(f"The directory {dirpath} is not yours.")
|
|
322
|
-
|
|
322
|
+
|
|
323
323
|
|
|
324
324
|
def check_file_or_directory_path(path, isdir=False):
|
|
325
325
|
"""
|
|
@@ -422,6 +422,26 @@ def load_json(json_path):
|
|
|
422
422
|
return data
|
|
423
423
|
|
|
424
424
|
|
|
425
|
+
def load_construct_json(json_path):
    """
    Load a construct JSON file and split out Megatron micro-step information.

    When the loaded data contains the key ``Const.MEGATRON_MICRO_STEP_NUMBER``,
    every other entry must be either a plain value or a two-element list
    ``[construct_value, micro_step]``:

    * list entries are split into ``construct_dict[key] = value[0]`` and
      ``micro_step_dict[key] = value[1]``;
    * plain entries are copied to ``construct_dict`` and get
      ``micro_step_dict[key] = 0``.

    ``micro_step_dict`` also keeps the ``Const.MEGATRON_MICRO_STEP_NUMBER`` value.
    When that key is absent, the loaded data is returned unchanged together with
    an empty dict.

    :param json_path: path of the construct JSON file (read via ``load_json``)
    :return: ``(construct_dict, micro_step_dict)``
    :raises RuntimeError: if a list entry does not have exactly two elements
    """
    construct_dict_o = load_json(json_path)
    if Const.MEGATRON_MICRO_STEP_NUMBER not in construct_dict_o:
        return construct_dict_o, {}

    construct_dict = {}
    # Remove the micro-step count from the entries so it is not treated as a construct entry.
    micro_step_dict = {
        Const.MEGATRON_MICRO_STEP_NUMBER: construct_dict_o.pop(Const.MEGATRON_MICRO_STEP_NUMBER)
    }
    for key, value in construct_dict_o.items():
        if isinstance(value, list):
            if len(value) != 2:
                logger.error(f'Parse construct json file "{os.path.basename(json_path)}" failed.')
                raise RuntimeError(
                    f'Invalid construct entry "{key}": expected a 2-element list '
                    f'[construct_value, micro_step], got {len(value)} element(s).'
                )
            construct_dict[key] = value[0]
            micro_step_dict[key] = value[1]
        else:
            construct_dict[key] = value
            micro_step_dict[key] = 0
    return construct_dict, micro_step_dict
|
|
443
|
+
|
|
444
|
+
|
|
425
445
|
def save_json(json_path, data, indent=None, mode="w"):
|
|
426
446
|
check_path_before_create(json_path)
|
|
427
447
|
json_path = os.path.realpath(json_path)
|
|
@@ -520,6 +540,9 @@ def move_directory(src_path, dst_path):
|
|
|
520
540
|
check_file_or_directory_path(src_path, isdir=True)
|
|
521
541
|
check_path_before_create(dst_path)
|
|
522
542
|
try:
|
|
543
|
+
if os.path.exists(dst_path):
|
|
544
|
+
logger.warning(f"The destination directory {dst_path} already exists, it will be removed.")
|
|
545
|
+
shutil.rmtree(dst_path)
|
|
523
546
|
shutil.move(src_path, dst_path)
|
|
524
547
|
except Exception as e:
|
|
525
548
|
logger.error(f"move directory {src_path} to {dst_path} failed")
|
msprobe/core/common/log.py
CHANGED
|
@@ -89,6 +89,13 @@ class BaseLogger:
|
|
|
89
89
|
self.error(msg)
|
|
90
90
|
raise exception
|
|
91
91
|
|
|
92
|
+
def warning_log_with_exp(self, msg, exception):
    """
    Log a warning message, then raise the given exception.

    :param msg: text passed to ``self.warning``
    :param exception: exception to raise after logging
    :raises: ``exception``
    """
    log_warning = self.warning
    log_warning(msg)
    raise exception
|
|
98
|
+
|
|
92
99
|
def _print_log(self, level, msg, end='\n'):
|
|
93
100
|
current_rank = self.get_rank()
|
|
94
101
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|