mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/common/file_utils.py

@@ -12,16 +12,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import atexit
 import csv
 import fcntl
+import io
 import os
+import pickle
+from multiprocessing import shared_memory
 import stat
 import json
 import re
 import shutil
-
-
+import sys
+import zipfile
+import multiprocessing
 import yaml
 import numpy as np
 import pandas as pd
@@ -29,7 +33,10 @@ import pandas as pd
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.common.log import logger
 from msprobe.core.common.exceptions import FileCheckException
-from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.const import FileCheckConst, CompareConst
+from msprobe.core.common.global_lock import global_lock, is_main_process
+
+proc_lock = multiprocessing.Lock()
 
 
 class FileChecker:
@@ -165,6 +172,12 @@ def check_path_exists(path):
     if not os.path.exists(path):
         logger.error('The file path %s does not exist.' % path)
         raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_path_not_exists(path):
+    if os.path.exists(path):
+        logger.error('The file path %s already exist.' % path)
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
 
 
 def check_path_readability(path):
@@ -299,12 +312,13 @@ def check_path_before_create(path):
 def check_dirpath_before_read(path):
     path = os.path.realpath(path)
     dirpath = os.path.dirname(path)
-    if check_others_writable(dirpath):
-        logger.warning(f"The directory is writable by others: {dirpath}.")
-    try:
-        check_path_owner_consistent(dirpath)
-    except FileCheckException:
-        logger.warning(f"The directory {dirpath} is not yours.")
+    if dedup_log('check_dirpath_before_read', dirpath):
+        if check_others_writable(dirpath):
+            logger.warning(f"The directory is writable by others: {dirpath}.")
+        try:
+            check_path_owner_consistent(dirpath)
+        except FileCheckException:
+            logger.warning(f"The directory {dirpath} is not yours.")
 
 
 def check_file_or_directory_path(path, isdir=False):
@@ -446,6 +460,17 @@ def save_excel(path, data):
             return "list"
         raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
 
+    def save_in_slice(df, base_name):
+        df_length = len(df)
+        if df_length < CompareConst.MAX_EXCEL_LENGTH:
+            df.to_excel(writer, sheet_name=base_name if base_name else 'Sheet1', index=False)
+        else:
+            slice_num = (df_length + CompareConst.MAX_EXCEL_LENGTH - 1) // CompareConst.MAX_EXCEL_LENGTH
+            slice_size = (df_length + slice_num - 1) // slice_num
+            for i in range(slice_num):
+                df.iloc[i * slice_size: min((i + 1) * slice_size, df_length)] \
+                    .to_excel(writer, sheet_name=f'{base_name}_part_{i}' if base_name else f'part_{i}', index=False)
+
     check_path_before_create(path)
     path = os.path.realpath(path)
 
@@ -453,18 +478,27 @@ def save_excel(path, data):
     data_type = validate_data(data)
 
     try:
-
-
-
-
+        with pd.ExcelWriter(path) as writer:
+            if data_type == "single":
+                save_in_slice(data, None)
+            elif data_type == "list":
                 for data_df, sheet_name in data:
-                data_df
+                    save_in_slice(data_df, sheet_name)
     except Exception as e:
         logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
         raise RuntimeError(f"Save excel file {path} failed.") from e
     change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
 
 
+def move_directory(src_path, dst_path):
+    check_file_or_directory_path(src_path, isdir=True)
+    check_path_before_create(dst_path)
+    try:
+        shutil.move(src_path, dst_path)
+    except Exception as e:
+        logger.error(f"move directory {src_path} to {dst_path} failed")
+        raise RuntimeError(f"move directory {src_path} to {dst_path} failed") from e
+    change_mode(dst_path, FileCheckConst.DATA_DIR_AUTHORITY)
 
 
 def move_file(src_path, dst_path):
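For orientation, a minimal usage sketch of the reworked save_excel (hedged: the output path and column names below are invented for illustration; per the hunk above, a sheet is only split into '<name>_part_<i>' slices once a frame reaches CompareConst.MAX_EXCEL_LENGTH rows):

    import pandas as pd
    from msprobe.core.common.file_utils import save_excel

    # Hypothetical frames; real callers pass comparison results.
    summary_df = pd.DataFrame({"api": ["Tensor.add", "Tensor.mul"], "max_diff": [0.0, 1e-6]})
    detail_df = pd.DataFrame({"index": list(range(10)), "cosine": [1.0] * 10})

    # A single DataFrame goes to 'Sheet1'; a list of (DataFrame, sheet_name) pairs
    # gets one sheet per pair, each sliced by save_in_slice when it grows too long.
    save_excel("./compare_result.xlsx", [(summary_df, "summary"), (detail_df, "detail")])
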
@@ -530,7 +564,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
         if not isinstance(value, str):
             return True
         try:
-            # -1.00 or +1.00 should be
+            # -1.00 or +1.00 should be considered as digit numbers
             float(value)
         except ValueError:
             # otherwise, they will be considered as formular injections
@@ -576,7 +610,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
         if not isinstance(value, str):
             return True
         try:
-            # -1.00 or +1.00 should be
+            # -1.00 or +1.00 should be considered as digit numbers
             float(value)
         except ValueError:
             # otherwise, they will be considered as formular injections
@@ -607,8 +641,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
 def remove_path(path):
     if not os.path.exists(path):
         return
+    if os.path.islink(path):
+        logger.error(f"Failed to delete {path}, it is a symbolic link.")
+        raise RuntimeError("Delete file or directory failed.")
     try:
-        if os.path.
+        if os.path.isfile(path):
             os.remove(path)
         else:
             shutil.rmtree(path)
@@ -617,7 +654,7 @@ def remove_path(path):
         raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err
     except Exception as e:
         logger.error("Failed to delete {}. Please check.".format(path))
-        raise RuntimeError(
+        raise RuntimeError("Delete file or directory failed.") from e
 
 
 def get_json_contents(file_path):
@@ -651,46 +688,238 @@ def os_walk_for_files(path, depth):
     return res
 
 
-def
+def check_zip_file(zip_file_path):
+    with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
+        total_size = 0
+        if len(zip_file.infolist()) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
+            raise ValueError(f"Too many files in {os.path.basename(zip_file_path)}")
+        for file_info in zip_file.infolist():
+            if file_info.file_size > FileCheckConst.MAX_FILE_SIZE:
+                raise ValueError(f"File {file_info.filename} is too large to extract")
+
+            total_size += file_info.file_size
+            if total_size > FileCheckConst.MAX_ZIP_SIZE:
+                raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
+
+
+def read_xlsx(file_path, sheet_name=None):
+    check_file_or_directory_path(file_path)
+    check_zip_file(file_path)
+    try:
+        if sheet_name:
+            result_df = pd.read_excel(file_path, keep_default_na=False, sheet_name=sheet_name)
+        else:
+            result_df = pd.read_excel(file_path, keep_default_na=False)
+    except Exception as e:
+        logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
+        raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
+    return result_df
+
+
+def create_file_with_list(result_list, filepath):
+    check_path_before_create(filepath)
+    filepath = os.path.realpath(filepath)
+    try:
+        with FileOpen(filepath, 'w', encoding='utf-8') as file:
+            fcntl.flock(file, fcntl.LOCK_EX)
+            for item in result_list:
+                file.write(item + '\n')
+            fcntl.flock(file, fcntl.LOCK_UN)
+    except Exception as e:
+        logger.error(f'Save list to file "{os.path.basename(filepath)}" failed.')
+        raise RuntimeError(f"Save list to file {os.path.basename(filepath)} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def create_file_with_content(data, filepath):
+    check_path_before_create(filepath)
+    filepath = os.path.realpath(filepath)
+    try:
+        with FileOpen(filepath, 'w', encoding='utf-8') as file:
+            fcntl.flock(file, fcntl.LOCK_EX)
+            file.write(data)
+            fcntl.flock(file, fcntl.LOCK_UN)
+    except Exception as e:
+        logger.error(f'Save content to file "{os.path.basename(filepath)}" failed.')
+        raise RuntimeError(f"Save content to file {os.path.basename(filepath)} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def add_file_to_zip(zip_file_path, file_path, arc_path=None):
     """
-
+    Add a file to a ZIP archive, if zip does not exist, create one.
 
-
-
+    :param zip_file_path: Path to the ZIP archive
+    :param file_path: Path to the file to add
+    :param arc_path: Optional path inside the ZIP archive where the file should be added
+    """
+    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
+    check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE)
+    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
+    if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE:
+        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
+    check_path_before_create(zip_file_path)
+    try:
+        proc_lock.acquire()
+        with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
+            zip_file.write(file_path, arc_path)
+    except Exception as e:
+        logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.')
+        raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e
+    finally:
+        proc_lock.release()
+    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
 
-    Parameters:
-        pem_path (str): The file path of the SSL certificate.
-        is_public_key (bool): The file is public key or not.
 
-
-        RuntimeError: If the SSL certificate is invalid or expired.
+def create_file_in_zip(zip_file_path, file_name, content):
     """
-
+    Create a file with content inside a ZIP archive.
+
+    :param zip_file_path: Path to the ZIP archive
+    :param file_name: Name of the file to create
+    :param content: Content to write to the file
+    """
+    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
+    check_path_before_create(zip_file_path)
+    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
+    if zip_size + sys.getsizeof(content) > FileCheckConst.MAX_ZIP_SIZE:
+        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
     try:
-        with
-
-
-
-
-
-
-
-        logger.info(f"The SSL certificate passes the verification and the validity period "
-                    f"starts from {pem_start} ends at {pem_end}.")
+        with open(zip_file_path, 'a+') as f:  # must be opened in 'a+' mode so the file can be flock-ed
+            # acquire an exclusive lock (blocks until it succeeds)
+            fcntl.flock(f, fcntl.LOCK_EX)  # LOCK_EX: exclusive lock
+            with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
+                zip_info = zipfile.ZipInfo(file_name)
+                zip_info.compress_type = zipfile.ZIP_DEFLATED
+                zip_file.writestr(zip_info, content)
+            fcntl.flock(f, fcntl.LOCK_UN)
     except Exception as e:
-        logger.error(
-        raise RuntimeError(f"
+        logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
+        raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
+    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
 
-    now_utc = datetime.now(tz=timezone.utc)
-    if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
-        raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
 
+def extract_zip(zip_file_path, extract_dir):
+    """
+    Extract the contents of a ZIP archive to a specified directory.
 
-
-
+    :param zip_file_path: Path to the ZIP archive
+    :param extract_dir: Directory to extract the contents to
+    """
+    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
     try:
-
+        proc_lock.acquire()
+        check_zip_file(zip_file_path)
     except Exception as e:
-        logger.error(f
-        raise RuntimeError(f"
-
+        logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
+        raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
+    finally:
+        proc_lock.release()
+    with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
+        zip_file.extractall(extract_dir)
+
+
+def split_zip_file_path(zip_file_path):
+    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
+    zip_file_path = os.path.realpath(zip_file_path)
+    return os.path.dirname(zip_file_path), os.path.basename(zip_file_path)
+
+
+def dedup_log(func_name, filter_name):
+    with SharedDict() as shared_dict:
+        exist_names = shared_dict.get(func_name, set())
+        if filter_name in exist_names:
+            return False
+        exist_names.add(filter_name)
+        shared_dict[func_name] = exist_names
+        return True
+
+
+class SharedDict:
+    def __init__(self):
+        self._changed = False
+        self._dict = None
+        self._shm = None
+
+    def __enter__(self):
+        self._load_shared_memory()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        try:
+            if self._changed:
+                data = pickle.dumps(self._dict)
+                global_lock.acquire()
+                try:
+                    self._shm.buf[0:len(data)] = bytearray(data)
+                finally:
+                    global_lock.release()
+            self._shm.close()
+        except FileNotFoundError:
+            name = self.get_shared_memory_name()
+            logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.')
+
+    def __setitem__(self, key, value):
+        self._dict[key] = value
+        self._changed = True
+
+    def __contains__(self, item):
+        return item in self._dict
+
+    @classmethod
+    def destroy_shared_memory(cls):
+        if is_main_process():
+            name = cls.get_shared_memory_name()
+            try:
+                shm = shared_memory.SharedMemory(create=False, name=name)
+                shm.close()
+                shm.unlink()
+                logger.debug(f'destroy shared memory, name: {name}')
+            except FileNotFoundError:
+                logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.')
+
+    @classmethod
+    def get_shared_memory_name(cls):
+        if is_main_process():
+            return f'shared_memory_{os.getpid()}'
+        return f'shared_memory_{os.getppid()}'
+
+    def get(self, key, default=None):
+        return self._dict.get(key, default)
+
+    def _load_shared_memory(self):
+        name = self.get_shared_memory_name()
+        try:
+            self._shm = shared_memory.SharedMemory(create=False, name=name)
+        except FileNotFoundError:
+            try:
+                # shared memory enlarged to 5 MB
+                self._shm = shared_memory.SharedMemory(create=True, name=name, size=1024 * 1024 * 5)
+                data = pickle.dumps({})
+                self._shm.buf[0:len(data)] = bytearray(data)
+                logger.debug(f'create shared memory, name: {name}')
+            except FileExistsError:
+                self._shm = shared_memory.SharedMemory(create=False, name=name)
+        self._safe_load()
+
+    def _safe_load(self):
+        with io.BytesIO(self._shm.buf[:]) as buff:
+            try:
+                self._dict = SafeUnpickler(buff).load()
+            except Exception as e:
+                logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.')
+                self._dict = {}
+                self._shm.buf[:] = bytearray(b'\x00' * len(self._shm.buf))  # clear the buffer
+                self._changed = True
+
+
+class SafeUnpickler(pickle.Unpickler):
+    WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}}
+
+    def find_class(self, module, name):
+        if module in self.WHITELIST and name in self.WHITELIST[module]:
+            return super().find_class(module, name)
+        raise pickle.PicklingError(f'Unpickling {module}.{name} is illegal!')
+
+
+atexit.register(SharedDict.destroy_shared_memory)
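A rough sketch of how the new dedup_log helper behaves (the key and directory below are hypothetical; the seen-set is persisted across worker processes through SharedDict, which pickles it into a named shared-memory block guarded by global_lock):

    from msprobe.core.common.file_utils import dedup_log

    # The first call for a given (function, key) pair returns True, repeats return False,
    # which is how check_dirpath_before_read above warns only once per directory.
    if dedup_log('check_dirpath_before_read', '/tmp/msprobe_output'):
        print('first sighting, emit the warning')
    if not dedup_log('check_dirpath_before_read', '/tmp/msprobe_output'):
        print('already recorded, stay quiet')
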
msprobe/core/common/framework_adapter.py (new file)

@@ -0,0 +1,169 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.import functools
+import functools
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import check_file_or_directory_path
+from msprobe.core.common.file_utils import save_npy
+
+
+class FrameworkDescriptor:
+    def __get__(self, instance, owner):
+        if owner._framework is None:
+            owner.import_framework()
+        return owner._framework
+
+
+class FmkAdp:
+    fmk = Const.PT_FRAMEWORK
+    supported_fmk = [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]
+    supported_dtype_list = ["bfloat16", "float16", "float32", "float64"]
+    _framework = None
+    framework = FrameworkDescriptor()
+
+    @classmethod
+    def import_framework(cls):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            import torch
+            cls._framework = torch
+        elif cls.fmk == Const.MS_FRAMEWORK:
+            import mindspore
+            cls._framework = mindspore
+        else:
+            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")
+
+    @classmethod
+    def set_fmk(cls, fmk=Const.PT_FRAMEWORK):
+        if fmk not in cls.supported_fmk:
+            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")
+        cls.fmk = fmk
+        cls._framework = None  # reset the framework so it is re-imported on the next access
+
+    @classmethod
+    def get_rank(cls):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            return cls.framework.distributed.get_rank()
+        return cls.framework.communication.get_rank()
+
+    @classmethod
+    def get_rank_id(cls):
+        if cls.is_initialized():
+            return cls.get_rank()
+        return 0
+
+    @classmethod
+    def is_initialized(cls):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            return cls.framework.distributed.is_initialized()
+        return cls.framework.communication.GlobalComm.INITED
+
+    @classmethod
+    def is_nn_module(cls, module):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            return isinstance(module, cls.framework.nn.Module)
+        return isinstance(module, cls.framework.nn.Cell)
+
+    @classmethod
+    def is_tensor(cls, tensor):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            return isinstance(tensor, cls.framework.Tensor)
+        return isinstance(tensor, cls.framework.Tensor)
+
+    @classmethod
+    def process_tensor(cls, tensor, func):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            if not tensor.is_floating_point() or tensor.dtype == cls.framework.float64:
+                tensor = tensor.float()
+            return float(func(tensor))
+        return float(func(tensor).asnumpy())
+
+    @classmethod
+    def tensor_max(cls, tensor):
+        return cls.process_tensor(tensor, lambda x: x.max())
+
+    @classmethod
+    def tensor_min(cls, tensor):
+        return cls.process_tensor(tensor, lambda x: x.min())
+
+    @classmethod
+    def tensor_mean(cls, tensor):
+        return cls.process_tensor(tensor, lambda x: x.mean())
+
+    @classmethod
+    def tensor_norm(cls, tensor):
+        return cls.process_tensor(tensor, lambda x: x.norm())
+
+    @classmethod
+    def save_tensor(cls, tensor, filepath):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            tensor_npy = tensor.cpu().detach().float().numpy()
+        else:
+            tensor_npy = tensor.asnumpy()
+        save_npy(tensor_npy, filepath)
+
+    @classmethod
+    def dtype(cls, dtype_str):
+        if dtype_str not in cls.supported_dtype_list:
+            raise Exception(f"{dtype_str} is not supported by adapter, not in {cls.supported_dtype_list}")
+        return getattr(cls.framework, dtype_str)
+
+    @classmethod
+    def named_parameters(cls, module):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            if not isinstance(module, cls.framework.nn.Module):
+                raise Exception(f"{module} is not a torch.nn.Module")
+            return module.named_parameters()
+        if not isinstance(module, cls.framework.nn.Cell):
+            raise Exception(f"{module} is not a mindspore.nn.Cell")
+        return module.parameters_and_names()
+
+    @classmethod
+    def register_forward_pre_hook(cls, module, hook, with_kwargs=False):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            if not isinstance(module, cls.framework.nn.Module):
+                raise Exception(f"{module} is not a torch.nn.Module")
+            module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
+        else:
+            if not isinstance(module, cls.framework.nn.Cell):
+                raise Exception(f"{module} is not a mindspore.nn.Cell")
+            original_construct = module.construct
+
+            @functools.wraps(original_construct)
+            def new_construct(*args, **kwargs):
+                if with_kwargs:
+                    hook(module, args, kwargs)
+                else:
+                    hook(module, args)
+                return original_construct(*args, **kwargs)
+
+            module.construct = new_construct
+
+    @classmethod
+    def load_checkpoint(cls, path, to_cpu=True, weights_only=True):
+        check_file_or_directory_path(path)
+        if cls.fmk == Const.PT_FRAMEWORK:
+            try:
+                if to_cpu:
+                    return cls.framework.load(path, map_location=cls.framework.device("cpu"), weights_only=weights_only)
+                else:
+                    return cls.framework.load(path, weights_only=weights_only)
+            except Exception as e:
+                raise RuntimeError(f"load pt file {path} failed: {e}") from e
+        return mindspore.load_checkpoint(path)
+
+    @classmethod
+    def asnumpy(cls, tensor):
+        if cls.fmk == Const.PT_FRAMEWORK:
+            return tensor.float().numpy()
+        return tensor.float().asnumpy()
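A brief, hedged usage sketch of the new FmkAdp adapter (assumes MindSpore is installed when Const.MS_FRAMEWORK is selected; tensor statistics are routed through the lazily imported framework exposed by FrameworkDescriptor):

    from msprobe.core.common.const import Const
    from msprobe.core.common.framework_adapter import FmkAdp

    FmkAdp.set_fmk(Const.MS_FRAMEWORK)  # PyTorch (Const.PT_FRAMEWORK) is the default
    rank = FmkAdp.get_rank_id()         # falls back to 0 while communication is uninitialized
    fp16 = FmkAdp.dtype("float16")      # resolved via getattr on the active framework
    # For a tensor t belonging to the active framework:
    # maximum = FmkAdp.tensor_max(t)
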
msprobe/core/common/global_lock.py (new file)

@@ -0,0 +1,86 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+from multiprocessing.shared_memory import SharedMemory
+import random
+import time
+import atexit
+import os
+
+from msprobe.core.common.log import logger
+
+
+def is_main_process():
+    return multiprocessing.current_process().name == 'MainProcess'
+
+
+class GlobalLock:
+    def __init__(self):
+        self.name = self.get_lock_name()
+        try:
+            self._shm = SharedMemory(create=False, name=self.name)
+            time.sleep(random.randint(0, 500) / 10000)  # wait a random amount of time so processes do not grab the lock simultaneously
+        except FileNotFoundError:
+            try:
+                self._shm = SharedMemory(create=True, name=self.name, size=1)
+                self._shm.buf[0] = 0
+                logger.debug(f'{self.name} is created.')
+            except FileExistsError:
+                self.__init__()
+
+    @classmethod
+    def get_lock_name(cls):
+        if is_main_process():
+            return f'global_lock_{os.getpid()}'
+        return f'global_lock_{os.getppid()}'
+
+    @classmethod
+    def is_lock_exist(cls):
+        try:
+            SharedMemory(create=False, name=cls.get_lock_name()).close()
+            return True
+        except FileNotFoundError:
+            return False
+
+    def cleanup(self):
+        self._shm.close()
+        if is_main_process():
+            try:
+                self._shm.unlink()
+                logger.debug(f'{self.name} is unlinked.')
+            except FileNotFoundError:
+                logger.warning(f'{self.name} has already been unlinked.')
+
+    def acquire(self, timeout=180):
+        """
+        acquire global lock, default timeout is 3 minutes.
+
+        :param float timeout: timeout(seconds), default value is 180.
+        """
+        start = time.time()
+        while time.time() - start < timeout:
+            if self._shm.buf[0] == 0:
+                self._shm.buf[0] = 1
+                return
+            time.sleep(random.randint(10, 500) / 10000)  # spin, waiting 1-50 ms
+        self._shm.buf[0] = 1
+
+    def release(self):
+        self._shm.buf[0] = 0
+
+
+global_lock = GlobalLock()
+atexit.register(global_lock.cleanup)
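Finally, a hedged sketch of how the module-level global_lock instance is meant to be used, mirroring what SharedDict.__exit__ does in the file_utils.py hunk above (acquire() spins on a one-byte shared-memory flag and proceeds anyway once the 180-second default timeout elapses):

    from msprobe.core.common.global_lock import global_lock

    global_lock.acquire()
    try:
        # Short cross-process critical section for workers forked from the same
        # main process, e.g. rewriting the pickled SharedDict buffer.
        pass
    finally:
        global_lock.release()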