mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from functools import wraps
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
# Current recursion depth of each decorated function, keyed by id() of the wrapped function.
recursion_depth = defaultdict(int)


def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
    """Decorate a function so that recursing deeper than ``max_depth`` raises an error.

    Every call to the wrapped function increments its counter in the module-level
    ``recursion_depth`` dict and decrements it when the call finishes (normally or
    by exception). When the counter exceeds ``max_depth``, an error message naming
    ``func_info`` is passed with a MsprobeException(RECURSION_LIMIT_ERROR) to
    ``logger.error_log_with_exp``, which logs it and raises the exception.

    :param func_info: human-readable function name used in the error message
    :param max_depth: maximum allowed nesting depth (defaults to Const.MAX_DEPTH)
    :return: the decorator to apply to the target function
    """
    def decorator(func):
        func_id = id(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            recursion_depth[func_id] += 1
            # The limit check sits inside the try so the counter is always
            # decremented, including when the limit error is raised. Raising before
            # the try leaked one level per overflow and permanently lowered the
            # effective limit for all later calls.
            try:
                if recursion_depth[func_id] > max_depth:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
|
|
@@ -21,19 +21,21 @@ class CodedException(Exception):
|
|
|
21
21
|
|
|
22
22
|
def __str__(self):
|
|
23
23
|
return self.error_info
|
|
24
|
-
|
|
25
|
-
|
|
24
|
+
|
|
25
|
+
|
|
26
26
|
class MsprobeException(CodedException):
|
|
27
27
|
INVALID_PARAM_ERROR = 0
|
|
28
28
|
OVERFLOW_NUMS_ERROR = 1
|
|
29
29
|
RECURSION_LIMIT_ERROR = 2
|
|
30
30
|
INTERFACE_USAGE_ERROR = 3
|
|
31
|
+
UNSUPPORTED_TYPE_ERROR = 4
|
|
31
32
|
|
|
32
33
|
err_strs = {
|
|
33
34
|
INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
|
|
34
35
|
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
|
|
35
36
|
RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
|
|
36
|
-
INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
|
|
37
|
+
INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ",
|
|
38
|
+
UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: "
|
|
37
39
|
}
|
|
38
40
|
|
|
39
41
|
|
|
@@ -12,23 +12,31 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
15
|
+
import atexit
|
|
16
16
|
import csv
|
|
17
17
|
import fcntl
|
|
18
|
+
import io
|
|
18
19
|
import os
|
|
20
|
+
import pickle
|
|
21
|
+
from multiprocessing import shared_memory
|
|
19
22
|
import stat
|
|
20
23
|
import json
|
|
21
24
|
import re
|
|
22
25
|
import shutil
|
|
23
|
-
|
|
24
|
-
|
|
26
|
+
import sys
|
|
27
|
+
import zipfile
|
|
28
|
+
import multiprocessing
|
|
25
29
|
import yaml
|
|
26
30
|
import numpy as np
|
|
27
31
|
import pandas as pd
|
|
28
32
|
|
|
33
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
29
34
|
from msprobe.core.common.log import logger
|
|
30
35
|
from msprobe.core.common.exceptions import FileCheckException
|
|
31
36
|
from msprobe.core.common.const import FileCheckConst
|
|
37
|
+
from msprobe.core.common.global_lock import global_lock, is_main_process
|
|
38
|
+
|
|
39
|
+
# Module-level lock serialising ZIP archive access in add_file_to_zip and extract_zip.
# NOTE(review): a multiprocessing.Lock created at import time is shared only with child
# processes that inherit it (e.g. via fork); independently launched processes each get
# their own lock — confirm this matches how these helpers are used.
proc_lock = multiprocessing.Lock()
|
|
32
40
|
|
|
33
41
|
|
|
34
42
|
class FileChecker:
|
|
@@ -164,6 +172,12 @@ def check_path_exists(path):
|
|
|
164
172
|
if not os.path.exists(path):
|
|
165
173
|
logger.error('The file path %s does not exist.' % path)
|
|
166
174
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def check_path_not_exists(path):
    """Ensure ``path`` does not exist yet.

    Returns None when the path is free; otherwise logs an error and raises
    FileCheckException(ILLEGAL_PATH_ERROR).
    """
    if not os.path.exists(path):
        return
    logger.error('The file path %s already exist.' % path)
    raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
|
|
167
181
|
|
|
168
182
|
|
|
169
183
|
def check_path_readability(path):
|
|
@@ -266,6 +280,7 @@ def make_dir(dir_path):
|
|
|
266
280
|
file_check.common_check()
|
|
267
281
|
|
|
268
282
|
|
|
283
|
+
@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
|
|
269
284
|
def create_directory(dir_path):
|
|
270
285
|
"""
|
|
271
286
|
Function Description:
|
|
@@ -297,12 +312,13 @@ def check_path_before_create(path):
|
|
|
297
312
|
def check_dirpath_before_read(path):
    """Warn about risky permissions on the directory that contains ``path``.

    The check runs at most once per directory (deduplicated through dedup_log).
    It warns if the directory is writable by others, and warns (instead of
    raising) if check_path_owner_consistent rejects it.
    """
    parent_dir = os.path.dirname(os.path.realpath(path))
    if not dedup_log('check_dirpath_before_read', parent_dir):
        return
    if check_others_writable(parent_dir):
        logger.warning(f"The directory is writable by others: {parent_dir}.")
    try:
        check_path_owner_consistent(parent_dir)
    except FileCheckException:
        logger.warning(f"The directory {parent_dir} is not yours.")
|
|
306
322
|
|
|
307
323
|
|
|
308
324
|
def check_file_or_directory_path(path, isdir=False):
|
|
@@ -332,6 +348,23 @@ def change_mode(path, mode):
|
|
|
332
348
|
'Failed to change {} authority. {}'.format(path, str(ex))) from ex
|
|
333
349
|
|
|
334
350
|
|
|
351
|
+
@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod')
def recursive_chmod(path):
    """
    Recursively change the permissions of everything below a directory.

    Files get FileCheckConst.DATA_FILE_AUTHORITY (documented as 640) and
    sub-directories get FileCheckConst.DATA_DIR_AUTHORITY (documented as 750).
    The directory ``path`` itself is not changed, only its contents.

    :param path: directory whose contents should have their permissions changed
    """
    # Only the immediate children are handled here; deeper levels are reached by
    # the recursive call below, so recursion_depth_decorator bounds the depth.
    # Iterating every level of os.walk while also recursing joined nested names
    # onto the wrong parent (``path`` instead of the walk root) and re-visited
    # sub-trees already handled by the recursion.
    _, dir_names, file_names = next(os.walk(path), (path, [], []))
    for file_name in file_names:
        file_path = os.path.join(path, file_name)
        change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
    for dir_name in dir_names:
        dir_path = os.path.join(path, dir_name)
        change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY)
        recursive_chmod(dir_path)
|
|
366
|
+
|
|
367
|
+
|
|
335
368
|
def path_len_exceeds_limit(file_path):
|
|
336
369
|
return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
|
|
337
370
|
len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
|
|
@@ -446,6 +479,15 @@ def save_excel(path, data):
|
|
|
446
479
|
change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
447
480
|
|
|
448
481
|
|
|
482
|
+
def move_directory(src_path, dst_path):
    """Move directory ``src_path`` to ``dst_path`` and set DATA_DIR_AUTHORITY on ``dst_path``.

    Both paths are validated first. Any failure from shutil.move is logged and
    re-raised as RuntimeError, with the original exception chained.
    """
    check_file_or_directory_path(src_path, isdir=True)
    check_path_before_create(dst_path)
    try:
        shutil.move(src_path, dst_path)
    except Exception as e:
        failure_msg = f"move directory {src_path} to {dst_path} failed"
        logger.error(failure_msg)
        raise RuntimeError(failure_msg) from e
    change_mode(dst_path, FileCheckConst.DATA_DIR_AUTHORITY)
|
|
449
491
|
|
|
450
492
|
|
|
451
493
|
def move_file(src_path, dst_path):
|
|
@@ -511,7 +553,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
511
553
|
if not isinstance(value, str):
|
|
512
554
|
return True
|
|
513
555
|
try:
|
|
514
|
-
# -1.00 or +1.00 should be
|
|
556
|
+
# -1.00 or +1.00 should be considered as digit numbers
|
|
515
557
|
float(value)
|
|
516
558
|
except ValueError:
|
|
517
559
|
# otherwise, they will be considered as formular injections
|
|
@@ -557,7 +599,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
|
|
|
557
599
|
if not isinstance(value, str):
|
|
558
600
|
return True
|
|
559
601
|
try:
|
|
560
|
-
# -1.00 or +1.00 should be
|
|
602
|
+
# -1.00 or +1.00 should be considered as digit numbers
|
|
561
603
|
float(value)
|
|
562
604
|
except ValueError:
|
|
563
605
|
# otherwise, they will be considered as formular injections
|
|
@@ -588,8 +630,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
|
|
|
588
630
|
def remove_path(path):
|
|
589
631
|
if not os.path.exists(path):
|
|
590
632
|
return
|
|
633
|
+
if os.path.islink(path):
|
|
634
|
+
logger.error(f"Failed to delete {path}, it is a symbolic link.")
|
|
635
|
+
raise RuntimeError("Delete file or directory failed.")
|
|
591
636
|
try:
|
|
592
|
-
if os.path.
|
|
637
|
+
if os.path.isfile(path):
|
|
593
638
|
os.remove(path)
|
|
594
639
|
else:
|
|
595
640
|
shutil.rmtree(path)
|
|
@@ -598,7 +643,7 @@ def remove_path(path):
|
|
|
598
643
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err
|
|
599
644
|
except Exception as e:
|
|
600
645
|
logger.error("Failed to delete {}. Please check.".format(path))
|
|
601
|
-
raise RuntimeError(
|
|
646
|
+
raise RuntimeError("Delete file or directory failed.") from e
|
|
602
647
|
|
|
603
648
|
|
|
604
649
|
def get_json_contents(file_path):
|
|
@@ -632,42 +677,231 @@ def os_walk_for_files(path, depth):
|
|
|
632
677
|
return res
|
|
633
678
|
|
|
634
679
|
|
|
635
|
-
def
|
|
680
|
+
def read_xlsx(file_path, sheet_name=None):
    """Read an Excel file into a pandas object with ``keep_default_na=False``.

    :param file_path: path to the workbook; validated before reading
    :param sheet_name: sheet to read; when falsy, pandas' default sheet selection is used
    :return: whatever ``pd.read_excel`` returns for the given arguments
    :raises RuntimeError: if pandas fails to load the file
    """
    check_file_or_directory_path(file_path)
    read_kwargs = {'keep_default_na': False}
    if sheet_name:
        read_kwargs['sheet_name'] = sheet_name
    try:
        return pd.read_excel(file_path, **read_kwargs)
    except Exception as e:
        logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
        raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def create_file_with_list(result_list, filepath):
    """Write each item of ``result_list`` on its own line to ``filepath``.

    The file is opened with FileOpen in 'w' mode (UTF-8), an exclusive flock is
    held while writing, and DATA_FILE_AUTHORITY is applied afterwards.

    :raises RuntimeError: if opening or writing the file fails
    """
    check_path_before_create(filepath)
    real_path = os.path.realpath(filepath)
    file_name = os.path.basename(real_path)
    try:
        with FileOpen(real_path, 'w', encoding='utf-8') as handle:
            fcntl.flock(handle, fcntl.LOCK_EX)
            for line in result_list:
                handle.write(line + '\n')
            fcntl.flock(handle, fcntl.LOCK_UN)
    except Exception as e:
        logger.error(f'Save list to file "{file_name}" failed.')
        raise RuntimeError(f"Save list to file {file_name} failed.") from e
    change_mode(real_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def create_file_with_content(data, filepath):
    """Write the text ``data`` to ``filepath``.

    The file is opened with FileOpen in 'w' mode (UTF-8), an exclusive flock is
    held while writing, and DATA_FILE_AUTHORITY is applied afterwards.

    :raises RuntimeError: if opening or writing the file fails
    """
    check_path_before_create(filepath)
    real_path = os.path.realpath(filepath)
    file_name = os.path.basename(real_path)
    try:
        with FileOpen(real_path, 'w', encoding='utf-8') as handle:
            fcntl.flock(handle, fcntl.LOCK_EX)
            handle.write(data)
            fcntl.flock(handle, fcntl.LOCK_UN)
    except Exception as e:
        logger.error(f'Save content to file "{file_name}" failed.')
        raise RuntimeError(f"Save content to file {file_name} failed.") from e
    change_mode(real_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def add_file_to_zip(zip_file_path, file_path, arc_path=None):
    """
    Add a file to a ZIP archive, if zip does not exist, create one.

    :param zip_file_path: Path to the ZIP archive (must carry the ZIP suffix)
    :param file_path: Path to the file to add
    :param arc_path: Optional path inside the ZIP archive where the file should be added
    :raises RuntimeError: if the archive would grow beyond FileCheckConst.MAX_ZIP_SIZE,
        or if writing into the archive fails
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE)
    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
    if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE:
        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
    check_path_before_create(zip_file_path)
    # `with` releases the lock only if it was actually acquired. Calling acquire()
    # inside the try with release() in finally would attempt to release an unheld
    # lock if the acquire were interrupted, masking the original error.
    with proc_lock:
        try:
            with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
                zip_file.write(file_path, arc_path)
        except Exception as e:
            logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.')
            raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e
    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
641
746
|
|
|
642
|
-
Parameters:
|
|
643
|
-
pem_path (str): The file path of the SSL certificate.
|
|
644
747
|
|
|
645
|
-
|
|
646
|
-
RuntimeError: If the SSL certificate is invalid or expired.
|
|
748
|
+
def create_file_in_zip(zip_file_path, file_name, content):
    """
    Create a file with content inside a ZIP archive (the archive is created if missing).

    :param zip_file_path: Path to the ZIP archive
    :param file_name: Name of the file to create inside the archive
    :param content: Content to write to the file (str is stored UTF-8 encoded by
        ZipFile.writestr, bytes are stored as-is)
    :raises RuntimeError: if the archive would exceed FileCheckConst.MAX_ZIP_SIZE,
        or if writing into the archive fails
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    check_path_before_create(zip_file_path)
    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
    # Measure the payload by the bytes actually stored. sys.getsizeof reports the
    # Python object's in-memory size (fixed header plus 1/2/4 bytes per character
    # for str), not the number of bytes written to the archive.
    if isinstance(content, str):
        content_size = len(content.encode('utf-8'))
    elif isinstance(content, (bytes, bytearray)):
        content_size = len(content)
    else:
        # Unexpected type: fall back to the previous estimate.
        content_size = sys.getsizeof(content)
    if zip_size + content_size > FileCheckConst.MAX_ZIP_SIZE:
        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
    try:
        # 'a+' creates the file if needed without truncating it, which gives a
        # descriptor we can flock.
        with open(zip_file_path, 'a+') as f:
            # Take an exclusive lock (blocks until it is acquired).
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
                    zip_info = zipfile.ZipInfo(file_name)
                    zip_info.compress_type = zipfile.ZIP_DEFLATED
                    zip_file.writestr(zip_info, content)
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)
    except Exception as e:
        logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
        raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
660
774
|
|
|
661
|
-
now_utc = datetime.now(tz=timezone.utc)
|
|
662
|
-
if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
|
|
663
|
-
raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
|
|
664
775
|
|
|
776
|
+
def extract_zip(zip_file_path, extract_dir):
    """
    Extract the contents of a ZIP archive to a specified directory.

    Before extracting, the archive's member list is checked against the
    FileCheckConst limits (member count, per-member size, total size).
    Validation and extraction use the same opened archive while ``proc_lock`` is
    held, so the archive cannot change between the check and the extraction.
    Previously the lock was released and the archive reopened before extractall.

    :param zip_file_path: Path to the ZIP archive
    :param extract_dir: Directory to extract the contents to
    :raises RuntimeError: if the archive cannot be opened or violates a limit
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    with proc_lock:
        zip_file = None
        try:
            zip_file = zipfile.ZipFile(zip_file_path, 'r')
            members = zip_file.infolist()
            # NOTE(review): this compares the member *count* with MAX_FILE_IN_ZIP_SIZE,
            # which add_file_to_zip uses as a per-file byte limit; a dedicated count
            # limit was probably intended — confirm.
            if len(members) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
                raise ValueError(f"Too many files in {os.path.basename(zip_file_path)}")
            total_size = 0
            for file_info in members:
                if file_info.file_size > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
                    raise ValueError(f"File {file_info.filename} is too large to extract")
                total_size += file_info.file_size
                if total_size > FileCheckConst.MAX_ZIP_SIZE:
                    raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
        except Exception as e:
            if zip_file is not None:
                zip_file.close()
            logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
            raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
        # Extraction errors propagate unwrapped, as before.
        with zip_file:
            zip_file.extractall(extract_dir)
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def split_zip_file_path(zip_file_path):
    """
    Check the ZIP suffix of ``zip_file_path`` and split its resolved path.

    :param zip_file_path: path to the archive; it is canonicalized with ``os.path.realpath``
    :return: ``(directory, file_name)`` of the canonical path
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    canonical_path = os.path.realpath(zip_file_path)
    directory, file_name = os.path.split(canonical_path)
    return directory, file_name
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
def dedup_log(func_name, filter_name):
    """
    Record ``filter_name`` under ``func_name`` in the shared-memory dict.

    :return: True if the name was not yet recorded for ``func_name`` (and it is
             recorded now), False if it was already present.
    """
    with SharedDict() as shared_dict:
        seen_names = shared_dict.get(func_name, set())
        already_recorded = filter_name in seen_names
        if not already_recorded:
            seen_names.add(filter_name)
            # Reassign so SharedDict marks itself changed and writes back on exit.
            shared_dict[func_name] = seen_names
    return not already_recorded
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
class SharedDict:
    """
    Dictionary stored as a pickle inside a named shared-memory block.

    Use as a context manager. ``__enter__`` attaches to the block named by
    ``get_shared_memory_name`` (creating a 1 MiB block if none exists) and loads
    its contents. ``__exit__`` writes the dict back if it was modified through
    ``__setitem__`` (or had to be recreated), then always detaches from the block.
    """

    def __init__(self):
        self._changed = False  # True once the local dict must be written back on exit
        self._dict = None  # local copy of the shared dict, loaded in __enter__
        self._shm = None  # attached SharedMemory handle, set in __enter__

    def __enter__(self):
        self._load_shared_memory()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            try:
                if self._changed:
                    self._write_back()
            finally:
                # Detach even when the write-back failed so the mapping is not leaked.
                self._shm.close()
        except FileNotFoundError:
            name = self.get_shared_memory_name()
            logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.')

    def __setitem__(self, key, value):
        self._dict[key] = value
        self._changed = True

    def __contains__(self, item):
        return item in self._dict

    @classmethod
    def destroy_shared_memory(cls):
        """Close and unlink the shared-memory block; does nothing outside the main process."""
        if is_main_process():
            name = cls.get_shared_memory_name()
            try:
                shm = shared_memory.SharedMemory(create=False, name=name)
                shm.close()
                shm.unlink()
                logger.debug(f'destroy shared memory, name: {name}')
            except FileNotFoundError:
                logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.')

    @classmethod
    def get_shared_memory_name(cls):
        """
        Return the block name, keyed by the main process PID.

        Non-main processes use their parent's PID (this assumes they are direct
        children of the main process).
        """
        if is_main_process():
            return f'shared_memory_{os.getpid()}'
        return f'shared_memory_{os.getppid()}'

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def _write_back(self):
        """Pickle the local dict and copy it into the shared block while holding ``global_lock``."""
        data = pickle.dumps(self._dict)
        # Fail with a clear message instead of an opaque memoryview size mismatch.
        if len(data) > self._shm.size:
            raise ValueError(
                f'shared dict needs {len(data)} bytes but shared memory '
                f'{self._shm.name} only holds {self._shm.size} bytes.'
            )
        global_lock.acquire()
        try:
            self._shm.buf[0:len(data)] = data
        finally:
            global_lock.release()

    def _load_shared_memory(self):
        """Attach to (or create) the named block, then load the dict from it."""
        name = self.get_shared_memory_name()
        try:
            self._shm = shared_memory.SharedMemory(create=False, name=name)
        except FileNotFoundError:
            try:
                self._shm = shared_memory.SharedMemory(create=True, name=name, size=1024 * 1024)
                data = pickle.dumps({})
                self._shm.buf[0:len(data)] = bytearray(data)
                logger.debug(f'create shared memory, name: {name}')
            except FileExistsError:
                # Another process created the block between our two attempts.
                self._shm = shared_memory.SharedMemory(create=False, name=name)
        self._safe_load()

    def _safe_load(self):
        """Unpickle the block with the whitelist unpickler; fall back to an empty dict if unreadable."""
        # NOTE(review): this read is not guarded by global_lock, so it may observe a
        # write in progress from another process; the fallback below then overwrites
        # the shared contents with an empty dict on exit. Confirm whether the read
        # should also hold global_lock.
        with io.BytesIO(self._shm.buf[:]) as buff:
            try:
                self._dict = SafeUnpickler(buff).load()
            except Exception as e:
                logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.')
                self._dict = {}
                self._changed = True
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
class SafeUnpickler(pickle.Unpickler):
    """
    Unpickler that resolves only a small whitelist of builtin types.

    Any global outside ``WHITELIST`` (functions, arbitrary classes) is rejected,
    so data read back cannot resolve arbitrary callables during unpickling.
    """
    # module name -> names that find_class may resolve in that module
    WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}}

    def find_class(self, module, name):
        if module in self.WHITELIST and name in self.WHITELIST[module]:
            return super().find_class(module, name)
        # UnpicklingError is the error type pickle uses for load-time failures.
        raise pickle.UnpicklingError(f'Unpickling {module}.{name} is illegal!')
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
# At interpreter exit, unlink the shared-memory block backing SharedDict
# (destroy_shared_memory does nothing outside the main process).
atexit.register(SharedDict.destroy_shared_memory)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import functools
|
|
16
|
+
from msprobe.core.common.const import Const
|
|
17
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
18
|
+
from msprobe.core.common.file_utils import save_npy
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FrameworkDescriptor:
    """Class-level descriptor that imports the owner's framework on first access and returns it."""

    def __get__(self, instance, owner):
        cached = owner._framework
        if cached is not None:
            return cached
        # First access (or after a reset): let the owner import and cache the module.
        owner.import_framework()
        return owner._framework
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FmkAdp:
    """
    Framework adapter that routes tensor/module helpers to PyTorch or MindSpore.

    ``fmk`` selects the backend (PyTorch by default; change it with ``set_fmk``).
    The backend module is imported lazily on first access to ``framework``.
    """
    fmk = Const.PT_FRAMEWORK
    supported_fmk = [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]
    supported_dtype_list = ["bfloat16", "float16", "float32", "float64"]
    _framework = None  # cached backend module; None until first access
    framework = FrameworkDescriptor()  # lazily imports and returns the backend module

    @classmethod
    def import_framework(cls):
        """Import the backend module selected by ``cls.fmk`` and cache it on the class."""
        if cls.fmk == Const.PT_FRAMEWORK:
            import torch
            cls._framework = torch
        elif cls.fmk == Const.MS_FRAMEWORK:
            import mindspore
            cls._framework = mindspore
        else:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")

    @classmethod
    def set_fmk(cls, fmk=Const.PT_FRAMEWORK):
        """Select the backend; raises if ``fmk`` is not in ``supported_fmk``."""
        if fmk not in cls.supported_fmk:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")
        cls.fmk = fmk
        cls._framework = None  # reset so the next access imports the newly selected backend

    @classmethod
    def get_rank(cls):
        """Return this process's rank from torch.distributed / mindspore.communication."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.get_rank()
        return cls.framework.communication.get_rank()

    @classmethod
    def get_rank_id(cls):
        """Return the rank when distributed communication is initialized, otherwise 0."""
        if cls.is_initialized():
            return cls.get_rank()
        return 0

    @classmethod
    def is_initialized(cls):
        """Return whether the backend's distributed communication has been initialized."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.is_initialized()
        return cls.framework.communication.GlobalComm.INITED

    @classmethod
    def is_nn_module(cls, module):
        """Return whether ``module`` is a ``torch.nn.Module`` / ``mindspore.nn.Cell``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return isinstance(module, cls.framework.nn.Module)
        return isinstance(module, cls.framework.nn.Cell)

    @classmethod
    def is_tensor(cls, tensor):
        """Return whether ``tensor`` is an instance of the active backend's ``Tensor`` class."""
        # Both backends expose ``Tensor`` at the top level, so no per-backend branch is needed.
        return isinstance(tensor, cls.framework.Tensor)

    @classmethod
    def process_tensor(cls, tensor, func):
        """
        Apply ``func`` to ``tensor`` and return the result as a Python float.

        ``func`` is presumably a reduction yielding a single-element tensor
        (see ``tensor_max`` etc.). For PyTorch, tensors that are not floating
        point, or are float64, are cast with ``.float()`` first. For MindSpore,
        the result is converted with ``asnumpy()``.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not tensor.is_floating_point() or tensor.dtype == cls.framework.float64:
                tensor = tensor.float()
            return float(func(tensor))
        return float(func(tensor).asnumpy())

    @classmethod
    def tensor_max(cls, tensor):
        return cls.process_tensor(tensor, lambda x: x.max())

    @classmethod
    def tensor_min(cls, tensor):
        return cls.process_tensor(tensor, lambda x: x.min())

    @classmethod
    def tensor_mean(cls, tensor):
        return cls.process_tensor(tensor, lambda x: x.mean())

    @classmethod
    def tensor_norm(cls, tensor):
        return cls.process_tensor(tensor, lambda x: x.norm())

    @classmethod
    def save_tensor(cls, tensor, filepath):
        """Convert ``tensor`` to a NumPy array and write it to ``filepath`` with ``save_npy``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            tensor_npy = tensor.cpu().detach().float().numpy()
        else:
            tensor_npy = tensor.asnumpy()
        save_npy(tensor_npy, filepath)

    @classmethod
    def dtype(cls, dtype_str):
        """Return the backend dtype named ``dtype_str``; raises if it is not in ``supported_dtype_list``."""
        if dtype_str not in cls.supported_dtype_list:
            raise Exception(f"{dtype_str} is not supported by adapter, not in {cls.supported_dtype_list}")
        return getattr(cls.framework, dtype_str)

    @classmethod
    def named_parameters(cls, module):
        """Return ``(name, parameter)`` pairs: ``named_parameters()`` for torch, ``parameters_and_names()`` for MindSpore."""
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            return module.named_parameters()
        if not isinstance(module, cls.framework.nn.Cell):
            raise Exception(f"{module} is not a mindspore.nn.Cell")
        return module.parameters_and_names()

    @classmethod
    def register_forward_pre_hook(cls, module, hook, with_kwargs=False):
        """
        Run ``hook`` before each forward call of ``module``.

        PyTorch uses the native ``register_forward_pre_hook``. For MindSpore,
        ``module.construct`` is wrapped so that ``hook(module, args[, kwargs])``
        runs first.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
        else:
            if not isinstance(module, cls.framework.nn.Cell):
                raise Exception(f"{module} is not a mindspore.nn.Cell")
            original_construct = module.construct

            @functools.wraps(original_construct)
            def new_construct(*args, **kwargs):
                # NOTE(review): the hook's return value is ignored here, unlike
                # PyTorch pre-hooks, which may return replacement inputs.
                if with_kwargs:
                    hook(module, args, kwargs)
                else:
                    hook(module, args)
                return original_construct(*args, **kwargs)

            module.construct = new_construct

    @classmethod
    def load_checkpoint(cls, path, to_cpu=True, weights_only=True):
        """
        Validate ``path`` and load the checkpoint with the active backend.

        :param to_cpu: PyTorch only; map storages to the CPU device when True
        :param weights_only: PyTorch only; forwarded to ``torch.load``
        :raises RuntimeError: PyTorch only, when ``torch.load`` fails
        """
        check_file_or_directory_path(path)
        if cls.fmk == Const.PT_FRAMEWORK:
            try:
                if to_cpu:
                    return cls.framework.load(path, map_location=cls.framework.device("cpu"), weights_only=weights_only)
                else:
                    return cls.framework.load(path, weights_only=weights_only)
            except Exception as e:
                raise RuntimeError(f"load pt file {path} failed: {e}") from e
        # ``mindspore`` is only imported locally in import_framework, so a bare
        # module-level ``mindspore`` name would raise NameError; use the cached module.
        return cls.framework.load_checkpoint(path)

    @classmethod
    def asnumpy(cls, tensor):
        """Return ``tensor`` as a float32 NumPy array."""
        if cls.fmk == Const.PT_FRAMEWORK:
            # Match save_tensor: move to host and drop autograd history, otherwise
            # ``.numpy()`` fails for device tensors and tensors that require grad.
            return tensor.cpu().detach().float().numpy()
        return tensor.float().asnumpy()
|