mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,13 +12,16 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import csv
|
|
18
17
|
import fcntl
|
|
19
18
|
import os
|
|
19
|
+
import stat
|
|
20
20
|
import json
|
|
21
21
|
import re
|
|
22
22
|
import shutil
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
|
+
from dateutil import parser
|
|
23
25
|
import yaml
|
|
24
26
|
import numpy as np
|
|
25
27
|
import pandas as pd
|
|
@@ -67,9 +69,11 @@ class FileChecker:
|
|
|
67
69
|
self.check_path_ability()
|
|
68
70
|
if self.is_script:
|
|
69
71
|
check_path_owner_consistent(self.file_path)
|
|
70
|
-
|
|
72
|
+
check_path_pattern_valid(self.file_path)
|
|
71
73
|
check_common_file_size(self.file_path)
|
|
72
74
|
check_file_suffix(self.file_path, self.file_type)
|
|
75
|
+
if self.path_type == FileCheckConst.FILE:
|
|
76
|
+
check_dirpath_before_read(self.file_path)
|
|
73
77
|
return self.file_path
|
|
74
78
|
|
|
75
79
|
def check_path_ability(self):
|
|
@@ -122,9 +126,10 @@ class FileOpen:
|
|
|
122
126
|
self.file_path = os.path.realpath(self.file_path)
|
|
123
127
|
check_path_length(self.file_path)
|
|
124
128
|
self.check_ability_and_owner()
|
|
125
|
-
|
|
129
|
+
check_path_pattern_valid(self.file_path)
|
|
126
130
|
if os.path.exists(self.file_path):
|
|
127
131
|
check_common_file_size(self.file_path)
|
|
132
|
+
check_dirpath_before_read(self.file_path)
|
|
128
133
|
|
|
129
134
|
def check_ability_and_owner(self):
|
|
130
135
|
if self.mode in self.SUPPORT_READ_MODE:
|
|
@@ -193,7 +198,7 @@ def check_path_owner_consistent(path):
|
|
|
193
198
|
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
|
|
194
199
|
|
|
195
200
|
|
|
196
|
-
def
|
|
201
|
+
def check_path_pattern_valid(path):
|
|
197
202
|
if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
|
|
198
203
|
logger.error('The file path %s contains special characters.' % (path))
|
|
199
204
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
|
|
@@ -217,7 +222,6 @@ def check_common_file_size(file_path):
|
|
|
217
222
|
check_file_size(file_path, max_size)
|
|
218
223
|
return
|
|
219
224
|
check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
|
|
220
|
-
|
|
221
225
|
|
|
222
226
|
|
|
223
227
|
def check_file_suffix(file_path, file_suffix):
|
|
@@ -238,9 +242,18 @@ def check_path_type(file_path, file_type):
|
|
|
238
242
|
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
|
|
239
243
|
|
|
240
244
|
|
|
245
|
+
def check_others_writable(directory):
|
|
246
|
+
dir_stat = os.stat(directory)
|
|
247
|
+
is_writable = (
|
|
248
|
+
bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写
|
|
249
|
+
bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写
|
|
250
|
+
)
|
|
251
|
+
return is_writable
|
|
252
|
+
|
|
253
|
+
|
|
241
254
|
def make_dir(dir_path):
|
|
242
|
-
dir_path = os.path.realpath(dir_path)
|
|
243
255
|
check_path_before_create(dir_path)
|
|
256
|
+
dir_path = os.path.realpath(dir_path)
|
|
244
257
|
if os.path.isdir(dir_path):
|
|
245
258
|
return
|
|
246
259
|
try:
|
|
@@ -262,8 +275,9 @@ def create_directory(dir_path):
|
|
|
262
275
|
Exception Description:
|
|
263
276
|
when invalid data throw exception
|
|
264
277
|
"""
|
|
265
|
-
|
|
278
|
+
check_link(dir_path)
|
|
266
279
|
check_path_before_create(dir_path)
|
|
280
|
+
dir_path = os.path.realpath(dir_path)
|
|
267
281
|
parent_dir = os.path.dirname(dir_path)
|
|
268
282
|
if not os.path.isdir(parent_dir):
|
|
269
283
|
create_directory(parent_dir)
|
|
@@ -271,6 +285,7 @@ def create_directory(dir_path):
|
|
|
271
285
|
|
|
272
286
|
|
|
273
287
|
def check_path_before_create(path):
|
|
288
|
+
check_link(path)
|
|
274
289
|
if path_len_exceeds_limit(path):
|
|
275
290
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
|
|
276
291
|
|
|
@@ -279,6 +294,17 @@ def check_path_before_create(path):
|
|
|
279
294
|
'The file path {} contains special characters.'.format(path))
|
|
280
295
|
|
|
281
296
|
|
|
297
|
+
def check_dirpath_before_read(path):
|
|
298
|
+
path = os.path.realpath(path)
|
|
299
|
+
dirpath = os.path.dirname(path)
|
|
300
|
+
if check_others_writable(dirpath):
|
|
301
|
+
logger.warning(f"The directory is writable by others: {dirpath}.")
|
|
302
|
+
try:
|
|
303
|
+
check_path_owner_consistent(dirpath)
|
|
304
|
+
except FileCheckException:
|
|
305
|
+
logger.warning(f"The directory {dirpath} is not yours.")
|
|
306
|
+
|
|
307
|
+
|
|
282
308
|
def check_file_or_directory_path(path, isdir=False):
|
|
283
309
|
"""
|
|
284
310
|
Function Description:
|
|
@@ -344,7 +370,7 @@ def load_yaml(yaml_path):
|
|
|
344
370
|
def load_npy(filepath):
|
|
345
371
|
check_file_or_directory_path(filepath)
|
|
346
372
|
try:
|
|
347
|
-
npy = np.load(filepath)
|
|
373
|
+
npy = np.load(filepath, allow_pickle=False)
|
|
348
374
|
except Exception as e:
|
|
349
375
|
logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
|
|
350
376
|
raise RuntimeError(f"Load numpy file {filepath} failed.") from e
|
|
@@ -354,7 +380,7 @@ def load_npy(filepath):
|
|
|
354
380
|
def load_json(json_path):
|
|
355
381
|
try:
|
|
356
382
|
with FileOpen(json_path, "r") as f:
|
|
357
|
-
fcntl.flock(f, fcntl.
|
|
383
|
+
fcntl.flock(f, fcntl.LOCK_SH)
|
|
358
384
|
data = json.load(f)
|
|
359
385
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
360
386
|
except Exception as e:
|
|
@@ -363,11 +389,11 @@ def load_json(json_path):
|
|
|
363
389
|
return data
|
|
364
390
|
|
|
365
391
|
|
|
366
|
-
def save_json(json_path, data, indent=None):
|
|
367
|
-
json_path = os.path.realpath(json_path)
|
|
392
|
+
def save_json(json_path, data, indent=None, mode="w"):
|
|
368
393
|
check_path_before_create(json_path)
|
|
394
|
+
json_path = os.path.realpath(json_path)
|
|
369
395
|
try:
|
|
370
|
-
with FileOpen(json_path,
|
|
396
|
+
with FileOpen(json_path, mode) as f:
|
|
371
397
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
372
398
|
json.dump(data, f, indent=indent)
|
|
373
399
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
@@ -378,8 +404,8 @@ def save_json(json_path, data, indent=None):
|
|
|
378
404
|
|
|
379
405
|
|
|
380
406
|
def save_yaml(yaml_path, data):
|
|
381
|
-
yaml_path = os.path.realpath(yaml_path)
|
|
382
407
|
check_path_before_create(yaml_path)
|
|
408
|
+
yaml_path = os.path.realpath(yaml_path)
|
|
383
409
|
try:
|
|
384
410
|
with FileOpen(yaml_path, 'w') as f:
|
|
385
411
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
@@ -391,6 +417,37 @@ def save_yaml(yaml_path, data):
|
|
|
391
417
|
change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
392
418
|
|
|
393
419
|
|
|
420
|
+
def save_excel(path, data):
|
|
421
|
+
def validate_data(data):
|
|
422
|
+
"""Validate that the data is a DataFrame or a list of (DataFrame, sheet_name) pairs."""
|
|
423
|
+
if isinstance(data, pd.DataFrame):
|
|
424
|
+
return "single"
|
|
425
|
+
elif isinstance(data, list):
|
|
426
|
+
if all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], pd.DataFrame) for item in data):
|
|
427
|
+
return "list"
|
|
428
|
+
raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
|
|
429
|
+
|
|
430
|
+
check_path_before_create(path)
|
|
431
|
+
path = os.path.realpath(path)
|
|
432
|
+
|
|
433
|
+
# 验证数据类型
|
|
434
|
+
data_type = validate_data(data)
|
|
435
|
+
|
|
436
|
+
try:
|
|
437
|
+
if data_type == "single":
|
|
438
|
+
data.to_excel(path, index=False)
|
|
439
|
+
elif data_type == "list":
|
|
440
|
+
with pd.ExcelWriter(path) as writer:
|
|
441
|
+
for data_df, sheet_name in data:
|
|
442
|
+
data_df.to_excel(writer, sheet_name=sheet_name, index=False)
|
|
443
|
+
except Exception as e:
|
|
444
|
+
logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
|
|
445
|
+
raise RuntimeError(f"Save excel file {path} failed.") from e
|
|
446
|
+
change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
|
|
394
451
|
def move_file(src_path, dst_path):
|
|
395
452
|
check_file_or_directory_path(src_path)
|
|
396
453
|
check_path_before_create(dst_path)
|
|
@@ -403,8 +460,8 @@ def move_file(src_path, dst_path):
|
|
|
403
460
|
|
|
404
461
|
|
|
405
462
|
def save_npy(data, filepath):
|
|
406
|
-
filepath = os.path.realpath(filepath)
|
|
407
463
|
check_path_before_create(filepath)
|
|
464
|
+
filepath = os.path.realpath(filepath)
|
|
408
465
|
try:
|
|
409
466
|
np.save(filepath, data)
|
|
410
467
|
except Exception as e:
|
|
@@ -425,6 +482,7 @@ def save_npy_to_txt(data, dst_file='', align=0):
|
|
|
425
482
|
pad_array = np.zeros((align - data.size % align,))
|
|
426
483
|
data = np.append(data, pad_array)
|
|
427
484
|
check_path_before_create(dst_file)
|
|
485
|
+
dst_file = os.path.realpath(dst_file)
|
|
428
486
|
try:
|
|
429
487
|
np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
|
|
430
488
|
except Exception as e:
|
|
@@ -438,8 +496,8 @@ def save_workbook(workbook, file_path):
|
|
|
438
496
|
workbook: 要保存的工作簿对象
|
|
439
497
|
file_path: 文件保存路径
|
|
440
498
|
"""
|
|
441
|
-
file_path = os.path.realpath(file_path)
|
|
442
499
|
check_path_before_create(file_path)
|
|
500
|
+
file_path = os.path.realpath(file_path)
|
|
443
501
|
try:
|
|
444
502
|
workbook.save(file_path)
|
|
445
503
|
except Exception as e:
|
|
@@ -451,7 +509,7 @@ def save_workbook(workbook, file_path):
|
|
|
451
509
|
def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
452
510
|
def csv_value_is_valid(value: str) -> bool:
|
|
453
511
|
if not isinstance(value, str):
|
|
454
|
-
return True
|
|
512
|
+
return True
|
|
455
513
|
try:
|
|
456
514
|
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
457
515
|
float(value)
|
|
@@ -459,16 +517,16 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
459
517
|
# otherwise, they will be considered as formular injections
|
|
460
518
|
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
461
519
|
return True
|
|
462
|
-
|
|
520
|
+
|
|
463
521
|
if malicious_check:
|
|
464
522
|
for row in data:
|
|
465
523
|
for cell in row:
|
|
466
524
|
if not csv_value_is_valid(cell):
|
|
467
|
-
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
525
|
+
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
468
526
|
f"to be written into the csv: {filepath}.")
|
|
469
527
|
|
|
470
|
-
file_path = os.path.realpath(filepath)
|
|
471
528
|
check_path_before_create(filepath)
|
|
529
|
+
file_path = os.path.realpath(filepath)
|
|
472
530
|
try:
|
|
473
531
|
with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
|
|
474
532
|
writer = csv.writer(f)
|
|
@@ -479,16 +537,54 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
479
537
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
480
538
|
|
|
481
539
|
|
|
482
|
-
def read_csv(filepath):
|
|
540
|
+
def read_csv(filepath, as_pd=True, header='infer'):
|
|
483
541
|
check_file_or_directory_path(filepath)
|
|
484
542
|
try:
|
|
485
|
-
|
|
543
|
+
if as_pd:
|
|
544
|
+
csv_data = pd.read_csv(filepath, header=header)
|
|
545
|
+
else:
|
|
546
|
+
with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
|
|
547
|
+
csv_reader = csv.reader(f, delimiter=',')
|
|
548
|
+
csv_data = list(csv_reader)
|
|
486
549
|
except Exception as e:
|
|
487
550
|
logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
|
|
488
551
|
raise RuntimeError(f"Read csv file {filepath} failed.") from e
|
|
489
552
|
return csv_data
|
|
490
553
|
|
|
491
554
|
|
|
555
|
+
def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False):
|
|
556
|
+
def csv_value_is_valid(value: str) -> bool:
|
|
557
|
+
if not isinstance(value, str):
|
|
558
|
+
return True
|
|
559
|
+
try:
|
|
560
|
+
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
561
|
+
float(value)
|
|
562
|
+
except ValueError:
|
|
563
|
+
# otherwise, they will be considered as formular injections
|
|
564
|
+
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
565
|
+
return True
|
|
566
|
+
|
|
567
|
+
if not isinstance(data, pd.DataFrame):
|
|
568
|
+
raise ValueError("The data type of data is not supported. Only support pd.DataFrame.")
|
|
569
|
+
|
|
570
|
+
if malicious_check:
|
|
571
|
+
for i in range(len(data)):
|
|
572
|
+
for j in range(len(data.columns)):
|
|
573
|
+
cell = data.iloc[i, j]
|
|
574
|
+
if not csv_value_is_valid(cell):
|
|
575
|
+
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
576
|
+
f"to be written into the csv: {filepath}.")
|
|
577
|
+
|
|
578
|
+
check_path_before_create(filepath)
|
|
579
|
+
file_path = os.path.realpath(filepath)
|
|
580
|
+
try:
|
|
581
|
+
data.to_csv(filepath, mode=mode, header=header, index=False)
|
|
582
|
+
except Exception as e:
|
|
583
|
+
logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
|
|
584
|
+
raise RuntimeError(f"Save csv file {file_path} failed.") from e
|
|
585
|
+
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
586
|
+
|
|
587
|
+
|
|
492
588
|
def remove_path(path):
|
|
493
589
|
if not os.path.exists(path):
|
|
494
590
|
return
|
|
@@ -521,3 +617,57 @@ def get_json_contents(file_path):
|
|
|
521
617
|
def get_file_content_bytes(file):
|
|
522
618
|
with FileOpen(file, 'rb') as file_handle:
|
|
523
619
|
return file_handle.read()
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
# 对os.walk设置遍历深度
|
|
623
|
+
def os_walk_for_files(path, depth):
    """
    Collect files under ``path`` while limiting how deep ``os.walk`` descends.

    ``path`` itself is level 0, its direct subdirectories are level 1, and so on.
    Files are collected only from directories whose level is strictly less than
    ``depth``. When a directory reaches ``depth``, its subdirectories are pruned
    and its own files are skipped.

    Every visited directory is validated with ``check_file_or_directory_path``.

    Parameters:
        path (str): Root directory to walk.
        depth (int): Number of directory levels, starting at ``path``, to collect files from.

    Returns:
        list[dict]: One ``{"file": <file name>, "root": <containing directory>}``
        entry per collected file.
    """
    res = []
    for root, dirs, files in os.walk(path, topdown=True):
        check_file_or_directory_path(root, isdir=True)
        # Derive the level from the path relative to ``path``. Counting separators
        # directly miscounts when ``path`` ends with a separator or is "/".
        rel_root = os.path.relpath(root, path)
        level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
        if level >= depth:
            # With topdown=True, emptying dirs in place stops os.walk from descending further.
            dirs[:] = []
        else:
            res.extend({"file": file, "root": root} for file in files)
    return res
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def check_crt_valid(pem_path):
    """
    Check the validity of the SSL certificate.

    Load the PEM certificate from ``pem_path``, parse its notBefore/notAfter
    timestamps and verify that the current UTC time lies inside that window.
    The success message is logged only after the validity check has passed.

    Parameters:
        pem_path (str): The file path of the SSL certificate.

    Raises:
        RuntimeError: If the certificate cannot be read or parsed, is not yet
            valid, or has expired.
    """
    import OpenSSL
    try:
        with FileOpen(pem_path, "r") as f:
            pem_data = f.read()
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
        pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
    except Exception as e:
        logger.error("Failed to parse the SSL certificate. Check the certificate.")
        raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e

    # Comparing a naive datetime with the timezone-aware "now" below would raise
    # TypeError. X.509 validity times are defined in UTC, so attach UTC when the
    # parsed value carries no offset.
    if pem_start.tzinfo is None:
        pem_start = pem_start.replace(tzinfo=timezone.utc)
    if pem_end.tzinfo is None:
        pem_end = pem_end.replace(tzinfo=timezone.utc)

    now_utc = datetime.now(tz=timezone.utc)
    if now_utc < pem_start:
        raise RuntimeError(f"The SSL certificate is not yet valid (valid from {pem_start}), {pem_path}")
    if cert.has_expired() or now_utc > pem_end:
        raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")

    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def read_xlsx(file_path):
    """
    Load an xlsx file into a pandas DataFrame.

    The path is validated with ``check_file_or_directory_path`` before reading.
    ``keep_default_na=False`` stops pandas from turning strings such as "NA" or
    empty cells into NaN.

    Raises:
        RuntimeError: If pandas fails to read the file (the original error is chained).
    """
    check_file_or_directory_path(file_path)
    try:
        return pd.read_excel(file_path, keep_default_na=False)
    except Exception as e:
        logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
        raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from msprobe.core.common.file_utils import load_yaml
|
|
3
18
|
|
msprobe/core/common/log.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import time
|
|
3
18
|
import sys
|
|
@@ -5,6 +20,16 @@ from functools import wraps
|
|
|
5
20
|
from msprobe.core.common.const import MsgConst
|
|
6
21
|
|
|
7
22
|
|
|
23
|
+
def filter_special_chars(func):
    """
    Decorate a logger method so that special characters in ``msg`` are replaced with '_'.

    Every entry of ``MsgConst.SPECIAL_CHAR`` found in the message is replaced
    before the wrapped method is called. This is presumably meant to prevent
    log injection; confirm against the contents of ``MsgConst.SPECIAL_CHAR``.
    Non-string messages (e.g. exception objects) are converted with ``str()``
    first, so that ``.replace`` does not fail on them.
    """
    @wraps(func)
    def func_level(self, msg, **kwargs):
        if not isinstance(msg, str):
            msg = str(msg)
        for char in MsgConst.SPECIAL_CHAR:
            msg = msg.replace(char, '_')
        return func(self, msg, **kwargs)

    return func_level
|
|
31
|
+
|
|
32
|
+
|
|
8
33
|
class BaseLogger:
|
|
9
34
|
def __init__(self):
|
|
10
35
|
self.rank = None
|
|
@@ -21,14 +46,6 @@ class BaseLogger:
|
|
|
21
46
|
def get_rank(self):
|
|
22
47
|
return self.rank
|
|
23
48
|
|
|
24
|
-
def filter_special_chars(func):
|
|
25
|
-
@wraps(func)
|
|
26
|
-
def func_level(self, msg, **kwargs):
|
|
27
|
-
for char in MsgConst.SPECIAL_CHAR:
|
|
28
|
-
msg = msg.replace(char, '_')
|
|
29
|
-
return func(self, msg, **kwargs)
|
|
30
|
-
return func_level
|
|
31
|
-
|
|
32
49
|
@filter_special_chars
|
|
33
50
|
def error(self, msg):
|
|
34
51
|
if self.level <= MsgConst.LogLevel.ERROR.value:
|
|
@@ -56,6 +73,7 @@ class BaseLogger:
|
|
|
56
73
|
return func(*args, **kwargs)
|
|
57
74
|
else:
|
|
58
75
|
return None
|
|
76
|
+
|
|
59
77
|
return func_rank_0
|
|
60
78
|
|
|
61
79
|
def info_on_rank_0(self, msg):
|
|
@@ -66,7 +84,7 @@ class BaseLogger:
|
|
|
66
84
|
|
|
67
85
|
def warning_on_rank_0(self, msg):
|
|
68
86
|
return self.on_rank_0(self.warning)(msg)
|
|
69
|
-
|
|
87
|
+
|
|
70
88
|
def error_log_with_exp(self, msg, exception):
|
|
71
89
|
self.error(msg)
|
|
72
90
|
raise exception
|