mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/core/common/const.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import stat
|
|
3
18
|
|
|
@@ -10,6 +25,7 @@ class Const:
|
|
|
10
25
|
"""
|
|
11
26
|
TOOL_NAME = "msprobe"
|
|
12
27
|
|
|
28
|
+
ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
|
|
13
29
|
SEP = "."
|
|
14
30
|
REGEX_PREFIX_MAX_LENGTH = 20
|
|
15
31
|
REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
|
|
@@ -20,6 +36,8 @@ class Const:
|
|
|
20
36
|
OFF = 'OFF'
|
|
21
37
|
BACKWARD = 'backward'
|
|
22
38
|
FORWARD = 'forward'
|
|
39
|
+
PROGRESS_TIMEOUT = 3000
|
|
40
|
+
EXCEPTION_NONE = None
|
|
23
41
|
JIT = 'Jit'
|
|
24
42
|
PRIMITIVE_PREFIX = 'Primitive'
|
|
25
43
|
DEFAULT_LIST = []
|
|
@@ -82,6 +100,7 @@ class Const:
|
|
|
82
100
|
GRAD_PROBE = "grad_probe"
|
|
83
101
|
TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
|
|
84
102
|
DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
|
|
103
|
+
DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
|
|
85
104
|
LEVEL_L0 = "L0"
|
|
86
105
|
LEVEL_L1 = "L1"
|
|
87
106
|
LEVEL_L2 = "L2"
|
|
@@ -93,6 +112,7 @@ class Const:
|
|
|
93
112
|
DATA = "data"
|
|
94
113
|
PT_FRAMEWORK = "pytorch"
|
|
95
114
|
MS_FRAMEWORK = "mindspore"
|
|
115
|
+
UNKNOWN_FRAMEWORK = "unknown"
|
|
96
116
|
DIRECTORY_LENGTH = 4096
|
|
97
117
|
FILE_NAME_LENGTH = 255
|
|
98
118
|
FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
|
|
@@ -103,6 +123,8 @@ class Const:
|
|
|
103
123
|
CPU_LOWERCASE = 'cpu'
|
|
104
124
|
CUDA_LOWERCASE = 'cuda'
|
|
105
125
|
DISTRIBUTED = 'Distributed'
|
|
126
|
+
DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
|
|
127
|
+
"Aten", "VF", "NPU", "Jit"]
|
|
106
128
|
|
|
107
129
|
# struct json param
|
|
108
130
|
ORIGIN_DATA = "origin_data"
|
|
@@ -113,21 +135,25 @@ class Const:
|
|
|
113
135
|
MODULE_WHITE_LIST = ["torch", "numpy"]
|
|
114
136
|
|
|
115
137
|
FUNC_SKIP_LIST = ["construct", "__call__"]
|
|
116
|
-
|
|
117
|
-
|
|
138
|
+
FILE_SKIP_LIST = ["msprobe", "MindSpeed"]
|
|
139
|
+
DATA_TYPE_SKIP_LIST = ["Primitive", "Jit"]
|
|
118
140
|
|
|
119
141
|
STACK_FILE_INDEX = 0
|
|
120
|
-
|
|
121
142
|
STACK_FUNC_INDEX = 2
|
|
122
|
-
|
|
123
143
|
STACK_FUNC_ELE_INDEX = 1
|
|
124
144
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
145
|
+
SCOPE_ID_INDEX = -1
|
|
146
|
+
SCOPE_DIRECTION_INDEX = -2
|
|
147
|
+
TYPE_NAME_INDEX = -3
|
|
148
|
+
LAYER_NAME_INDEX = -4
|
|
149
|
+
API_TYPE_INDEX = 0
|
|
150
|
+
LEFT_MOVE_INDEX = -1
|
|
151
|
+
RIGHT_MOVE_INDEX = 1
|
|
152
|
+
|
|
153
|
+
TOP_LAYER = "TopLayer"
|
|
154
|
+
CELL = "Cell"
|
|
155
|
+
MODULE = "Module"
|
|
156
|
+
FRAME_FILE_LIST = ["site-packages/torch", "package/torch", "site-packages/mindspore", "package/mindspore"]
|
|
131
157
|
INPLACE_LIST = [
|
|
132
158
|
"broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
|
|
133
159
|
"_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
|
|
@@ -145,11 +171,12 @@ class Const:
|
|
|
145
171
|
FILL_CHAR_NUMS = 50
|
|
146
172
|
TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
|
|
147
173
|
WITHOUT_CALL_STACK = "The call stack retrieval failed."
|
|
148
|
-
|
|
174
|
+
|
|
149
175
|
STEP = "step"
|
|
150
176
|
RANK = "rank"
|
|
151
177
|
HYPHEN = "-"
|
|
152
|
-
|
|
178
|
+
STEP_RANK_MINIMUM_VALUE = 0
|
|
179
|
+
STEP_RANK_MAXIMUM_VALUE = int(1e6)
|
|
153
180
|
|
|
154
181
|
# data type const
|
|
155
182
|
FLOAT16 = "Float16"
|
|
@@ -159,6 +186,13 @@ class Const:
|
|
|
159
186
|
TORCH_FLOAT32 = "torch.float32"
|
|
160
187
|
TORCH_BFLOAT16 = "torch.bfloat16"
|
|
161
188
|
|
|
189
|
+
DTYPE = 'dtype'
|
|
190
|
+
SHAPE = 'shape'
|
|
191
|
+
MAX = 'Max'
|
|
192
|
+
MIN = 'Min'
|
|
193
|
+
MEAN = 'Mean'
|
|
194
|
+
NORM = 'Norm'
|
|
195
|
+
|
|
162
196
|
|
|
163
197
|
class CompareConst:
|
|
164
198
|
"""
|
|
@@ -201,10 +235,17 @@ class CompareConst:
|
|
|
201
235
|
RESULT = "Result"
|
|
202
236
|
MAGNITUDE = 0.5
|
|
203
237
|
OP_NAME = "op_name"
|
|
238
|
+
STRUCT = "struct"
|
|
204
239
|
INPUT_STRUCT = "input_struct"
|
|
240
|
+
KWARGS_STRUCT = "kwargs_struct"
|
|
205
241
|
OUTPUT_STRUCT = "output_struct"
|
|
206
242
|
SUMMARY = "summary"
|
|
207
243
|
MAX_EXCEL_LENGTH = 1048576
|
|
244
|
+
YES = "Yes"
|
|
245
|
+
NO = "No"
|
|
246
|
+
STATISTICS_INDICATOR_NUM = 4
|
|
247
|
+
EPSILON = 1e-10
|
|
248
|
+
COMPARE_ENDS_SUCCESSFULLY = "msprobe compare ends successfully."
|
|
208
249
|
|
|
209
250
|
COMPARE_RESULT_HEADER = [
|
|
210
251
|
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
|
|
@@ -222,6 +263,12 @@ class CompareConst:
|
|
|
222
263
|
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
|
|
223
264
|
]
|
|
224
265
|
|
|
266
|
+
HEAD_OF_COMPARE_MODE = {
|
|
267
|
+
Const.ALL: COMPARE_RESULT_HEADER,
|
|
268
|
+
Const.SUMMARY: SUMMARY_COMPARE_RESULT_HEADER,
|
|
269
|
+
Const.MD5: MD5_COMPARE_RESULT_HEADER
|
|
270
|
+
}
|
|
271
|
+
|
|
225
272
|
# compare standard
|
|
226
273
|
HUNDRED_RATIO_THRESHOLD = 0.01
|
|
227
274
|
THOUSAND_RATIO_THRESHOLD = 0.001
|
|
@@ -241,6 +288,8 @@ class CompareConst:
|
|
|
241
288
|
PASS = 'pass'
|
|
242
289
|
WARNING = 'Warning'
|
|
243
290
|
ERROR = 'error'
|
|
291
|
+
TRUE = 'TRUE'
|
|
292
|
+
FALSE = 'FALSE'
|
|
244
293
|
SKIP = 'SKIP'
|
|
245
294
|
N_A = 'N/A'
|
|
246
295
|
INF = 'inf'
|
|
@@ -298,6 +347,13 @@ class CompareConst:
|
|
|
298
347
|
MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
|
|
299
348
|
MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
|
|
300
349
|
}
|
|
350
|
+
INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
|
|
351
|
+
KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
|
|
352
|
+
OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
|
|
353
|
+
COMPARE_KEY = 'compare_key'
|
|
354
|
+
COMPARE_SHAPE = 'compare_shape'
|
|
355
|
+
INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
|
|
356
|
+
UNREADABLE = 'unreadable data'
|
|
301
357
|
|
|
302
358
|
|
|
303
359
|
class FileCheckConst:
|
|
@@ -322,7 +378,7 @@ class FileCheckConst:
|
|
|
322
378
|
MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
323
379
|
MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
324
380
|
MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
325
|
-
MAX_YAML_SIZE =
|
|
381
|
+
MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
326
382
|
COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
|
|
327
383
|
DIR = "dir"
|
|
328
384
|
FILE = "file"
|
|
@@ -351,6 +407,9 @@ class MsCompareConst:
|
|
|
351
407
|
# api_info field
|
|
352
408
|
MINT = "Mint"
|
|
353
409
|
MINT_FUNCTIONAL = "MintFunctional"
|
|
410
|
+
TENSOR_API = "Tensor"
|
|
411
|
+
|
|
412
|
+
API_NAME_STR_LENGTH = 4
|
|
354
413
|
|
|
355
414
|
TASK_FIELD = "task"
|
|
356
415
|
STATISTICS_TASK = "statistics"
|
|
@@ -358,6 +417,10 @@ class MsCompareConst:
|
|
|
358
417
|
DUMP_DATA_DIR_FIELD = "dump_data_dir"
|
|
359
418
|
DATA_FIELD = "data"
|
|
360
419
|
|
|
420
|
+
# supported api yaml
|
|
421
|
+
SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
|
|
422
|
+
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
423
|
+
|
|
361
424
|
# detail_csv
|
|
362
425
|
DETAIL_CSV_API_NAME = "API Name"
|
|
363
426
|
DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
|
|
@@ -382,15 +445,20 @@ class MsgConst:
|
|
|
382
445
|
MSPROBE_LOG_LEVEL = "MSPROBE_LOG_LEVEL"
|
|
383
446
|
LOG_LEVEL_ENUM = ["0", "1", "2", "3", "4"]
|
|
384
447
|
LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
|
|
448
|
+
|
|
385
449
|
class LogLevel:
|
|
386
450
|
class DEBUG:
|
|
387
451
|
value = 0
|
|
452
|
+
|
|
388
453
|
class INFO:
|
|
389
454
|
value = 1
|
|
455
|
+
|
|
390
456
|
class WARNING:
|
|
391
457
|
value = 2
|
|
458
|
+
|
|
392
459
|
class ERROR:
|
|
393
460
|
value = 3
|
|
461
|
+
|
|
394
462
|
SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
|
|
395
463
|
|
|
396
464
|
NOT_CREATED_INSTANCE = "PrecisionDebugger instance is not created."
|
|
@@ -400,3 +468,35 @@ class GraphMode:
|
|
|
400
468
|
NPY_MODE = "NPY_MODE"
|
|
401
469
|
STATISTIC_MODE = "STATISTIC_MODE"
|
|
402
470
|
ERROR_MODE = "ERROR_MODE"
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
class MonitorConst:
|
|
474
|
+
"""
|
|
475
|
+
Class for monitor const
|
|
476
|
+
"""
|
|
477
|
+
OP_LIST = ["min", "max", "norm", "zeros", "nans", "id", "mean"]
|
|
478
|
+
MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
|
|
479
|
+
DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
|
|
480
|
+
DATABASE = "database"
|
|
481
|
+
EMAIL = "email"
|
|
482
|
+
OPT_TY = ['Megatron_DistributedOptimizer', 'Megatron_Float16OptimizerWithFloat16Params']
|
|
483
|
+
DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage0", "DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3")
|
|
484
|
+
RULE_NAME = ['AnomalyTurbulence']
|
|
485
|
+
|
|
486
|
+
DOT = "."
|
|
487
|
+
VPP_SEP = ":"
|
|
488
|
+
ACTV_IN = "input"
|
|
489
|
+
ACTV_OUT = "output"
|
|
490
|
+
ACTVGRAD_IN = "input_grad"
|
|
491
|
+
ACTVGRAD_OUT = "output_grad"
|
|
492
|
+
POST_GRAD = "post_grad"
|
|
493
|
+
PRE_GRAD = "pre_grad"
|
|
494
|
+
PREFIX_POST = "post"
|
|
495
|
+
PREFIX_PRE = "pre"
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
ANOMALY_JSON = "anomaly.json"
|
|
499
|
+
ANALYSE_JSON = "anomaly_analyse.json"
|
|
500
|
+
TENSORBOARD = "tensorboard"
|
|
501
|
+
CSV = "csv"
|
|
502
|
+
API = "api"
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
class CodedException(Exception):
|
|
2
17
|
def __init__(self, code, error_info=''):
|
|
3
18
|
super().__init__()
|
|
@@ -11,10 +26,12 @@ class CodedException(Exception):
|
|
|
11
26
|
class MsprobeException(CodedException):
|
|
12
27
|
INVALID_PARAM_ERROR = 0
|
|
13
28
|
OVERFLOW_NUMS_ERROR = 1
|
|
29
|
+
RECURSION_LIMIT_ERROR = 2
|
|
14
30
|
|
|
15
31
|
err_strs = {
|
|
16
32
|
INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
|
|
17
|
-
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
|
|
33
|
+
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
|
|
34
|
+
RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:"
|
|
18
35
|
}
|
|
19
36
|
|
|
20
37
|
|
|
@@ -41,7 +58,7 @@ class ParseJsonException(CodedException):
|
|
|
41
58
|
InvalidDumpJson = 1
|
|
42
59
|
err_strs = {
|
|
43
60
|
UnexpectedNameStruct: "[msprobe] Unexpected name in json: ",
|
|
44
|
-
InvalidDumpJson: "[msprobe] json
|
|
61
|
+
InvalidDumpJson: "[msprobe] Invalid dump.json format: ",
|
|
45
62
|
}
|
|
46
63
|
|
|
47
64
|
|
|
@@ -73,9 +90,13 @@ class StepException(CodedException):
|
|
|
73
90
|
class FreeBenchmarkException(CodedException):
|
|
74
91
|
UnsupportedType = 0
|
|
75
92
|
InvalidGrad = 1
|
|
93
|
+
InvalidPerturbedOutput = 2
|
|
94
|
+
OutputIndexError = 3
|
|
76
95
|
err_strs = {
|
|
77
96
|
UnsupportedType: "[msprobe] Free benchmark get unsupported type: ",
|
|
78
97
|
InvalidGrad: "[msprobe] Free benchmark gradient invalid: ",
|
|
98
|
+
InvalidPerturbedOutput: "[msprobe] Free benchmark invalid perturbed output: ",
|
|
99
|
+
OutputIndexError: "[msprobe] Free benchmark output index out of bounds: ",
|
|
79
100
|
}
|
|
80
101
|
|
|
81
102
|
|
|
@@ -87,6 +108,7 @@ class DistributedNotInitializedError(Exception):
|
|
|
87
108
|
def __str__(self):
|
|
88
109
|
return self.msg
|
|
89
110
|
|
|
111
|
+
|
|
90
112
|
class ApiAccuracyCheckerException(CodedException):
|
|
91
113
|
ParseJsonFailed = 0
|
|
92
114
|
UnsupportType = 1
|
|
@@ -97,4 +119,4 @@ class ApiAccuracyCheckerException(CodedException):
|
|
|
97
119
|
UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
|
|
98
120
|
WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
|
|
99
121
|
ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
|
|
100
|
-
}
|
|
122
|
+
}
|
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,13 +12,17 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import csv
|
|
18
17
|
import fcntl
|
|
19
18
|
import os
|
|
19
|
+
import stat
|
|
20
20
|
import json
|
|
21
21
|
import re
|
|
22
22
|
import shutil
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
|
+
from dateutil import parser
|
|
25
|
+
import OpenSSL
|
|
23
26
|
import yaml
|
|
24
27
|
import numpy as np
|
|
25
28
|
import pandas as pd
|
|
@@ -67,9 +70,11 @@ class FileChecker:
|
|
|
67
70
|
self.check_path_ability()
|
|
68
71
|
if self.is_script:
|
|
69
72
|
check_path_owner_consistent(self.file_path)
|
|
70
|
-
|
|
73
|
+
check_path_pattern_valid(self.file_path)
|
|
71
74
|
check_common_file_size(self.file_path)
|
|
72
75
|
check_file_suffix(self.file_path, self.file_type)
|
|
76
|
+
if self.path_type == FileCheckConst.FILE:
|
|
77
|
+
check_dirpath_before_read(self.file_path)
|
|
73
78
|
return self.file_path
|
|
74
79
|
|
|
75
80
|
def check_path_ability(self):
|
|
@@ -122,9 +127,10 @@ class FileOpen:
|
|
|
122
127
|
self.file_path = os.path.realpath(self.file_path)
|
|
123
128
|
check_path_length(self.file_path)
|
|
124
129
|
self.check_ability_and_owner()
|
|
125
|
-
|
|
130
|
+
check_path_pattern_valid(self.file_path)
|
|
126
131
|
if os.path.exists(self.file_path):
|
|
127
132
|
check_common_file_size(self.file_path)
|
|
133
|
+
check_dirpath_before_read(self.file_path)
|
|
128
134
|
|
|
129
135
|
def check_ability_and_owner(self):
|
|
130
136
|
if self.mode in self.SUPPORT_READ_MODE:
|
|
@@ -193,7 +199,7 @@ def check_path_owner_consistent(path):
|
|
|
193
199
|
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
|
|
194
200
|
|
|
195
201
|
|
|
196
|
-
def
|
|
202
|
+
def check_path_pattern_valid(path):
|
|
197
203
|
if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
|
|
198
204
|
logger.error('The file path %s contains special characters.' % (path))
|
|
199
205
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
|
|
@@ -217,7 +223,6 @@ def check_common_file_size(file_path):
|
|
|
217
223
|
check_file_size(file_path, max_size)
|
|
218
224
|
return
|
|
219
225
|
check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
|
|
220
|
-
|
|
221
226
|
|
|
222
227
|
|
|
223
228
|
def check_file_suffix(file_path, file_suffix):
|
|
@@ -238,9 +243,18 @@ def check_path_type(file_path, file_type):
|
|
|
238
243
|
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
|
|
239
244
|
|
|
240
245
|
|
|
246
|
+
def check_others_writable(directory):
|
|
247
|
+
dir_stat = os.stat(directory)
|
|
248
|
+
is_writable = (
|
|
249
|
+
bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写
|
|
250
|
+
bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写
|
|
251
|
+
)
|
|
252
|
+
return is_writable
|
|
253
|
+
|
|
254
|
+
|
|
241
255
|
def make_dir(dir_path):
|
|
242
|
-
dir_path = os.path.realpath(dir_path)
|
|
243
256
|
check_path_before_create(dir_path)
|
|
257
|
+
dir_path = os.path.realpath(dir_path)
|
|
244
258
|
if os.path.isdir(dir_path):
|
|
245
259
|
return
|
|
246
260
|
try:
|
|
@@ -262,8 +276,9 @@ def create_directory(dir_path):
|
|
|
262
276
|
Exception Description:
|
|
263
277
|
when invalid data throw exception
|
|
264
278
|
"""
|
|
265
|
-
|
|
279
|
+
check_link(dir_path)
|
|
266
280
|
check_path_before_create(dir_path)
|
|
281
|
+
dir_path = os.path.realpath(dir_path)
|
|
267
282
|
parent_dir = os.path.dirname(dir_path)
|
|
268
283
|
if not os.path.isdir(parent_dir):
|
|
269
284
|
create_directory(parent_dir)
|
|
@@ -271,6 +286,7 @@ def create_directory(dir_path):
|
|
|
271
286
|
|
|
272
287
|
|
|
273
288
|
def check_path_before_create(path):
|
|
289
|
+
check_link(path)
|
|
274
290
|
if path_len_exceeds_limit(path):
|
|
275
291
|
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
|
|
276
292
|
|
|
@@ -279,6 +295,17 @@ def check_path_before_create(path):
|
|
|
279
295
|
'The file path {} contains special characters.'.format(path))
|
|
280
296
|
|
|
281
297
|
|
|
298
|
+
def check_dirpath_before_read(path):
|
|
299
|
+
path = os.path.realpath(path)
|
|
300
|
+
dirpath = os.path.dirname(path)
|
|
301
|
+
if check_others_writable(dirpath):
|
|
302
|
+
logger.warning(f"The directory is writable by others: {dirpath}.")
|
|
303
|
+
try:
|
|
304
|
+
check_path_owner_consistent(dirpath)
|
|
305
|
+
except FileCheckException:
|
|
306
|
+
logger.warning(f"The directory {dirpath} is not yours.")
|
|
307
|
+
|
|
308
|
+
|
|
282
309
|
def check_file_or_directory_path(path, isdir=False):
|
|
283
310
|
"""
|
|
284
311
|
Function Description:
|
|
@@ -344,7 +371,7 @@ def load_yaml(yaml_path):
|
|
|
344
371
|
def load_npy(filepath):
|
|
345
372
|
check_file_or_directory_path(filepath)
|
|
346
373
|
try:
|
|
347
|
-
npy = np.load(filepath)
|
|
374
|
+
npy = np.load(filepath, allow_pickle=False)
|
|
348
375
|
except Exception as e:
|
|
349
376
|
logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
|
|
350
377
|
raise RuntimeError(f"Load numpy file {filepath} failed.") from e
|
|
@@ -354,7 +381,7 @@ def load_npy(filepath):
|
|
|
354
381
|
def load_json(json_path):
|
|
355
382
|
try:
|
|
356
383
|
with FileOpen(json_path, "r") as f:
|
|
357
|
-
fcntl.flock(f, fcntl.
|
|
384
|
+
fcntl.flock(f, fcntl.LOCK_SH)
|
|
358
385
|
data = json.load(f)
|
|
359
386
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
360
387
|
except Exception as e:
|
|
@@ -363,11 +390,11 @@ def load_json(json_path):
|
|
|
363
390
|
return data
|
|
364
391
|
|
|
365
392
|
|
|
366
|
-
def save_json(json_path, data, indent=None):
|
|
367
|
-
json_path = os.path.realpath(json_path)
|
|
393
|
+
def save_json(json_path, data, indent=None, mode="w"):
|
|
368
394
|
check_path_before_create(json_path)
|
|
395
|
+
json_path = os.path.realpath(json_path)
|
|
369
396
|
try:
|
|
370
|
-
with FileOpen(json_path,
|
|
397
|
+
with FileOpen(json_path, mode) as f:
|
|
371
398
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
372
399
|
json.dump(data, f, indent=indent)
|
|
373
400
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
@@ -378,8 +405,8 @@ def save_json(json_path, data, indent=None):
|
|
|
378
405
|
|
|
379
406
|
|
|
380
407
|
def save_yaml(yaml_path, data):
|
|
381
|
-
yaml_path = os.path.realpath(yaml_path)
|
|
382
408
|
check_path_before_create(yaml_path)
|
|
409
|
+
yaml_path = os.path.realpath(yaml_path)
|
|
383
410
|
try:
|
|
384
411
|
with FileOpen(yaml_path, 'w') as f:
|
|
385
412
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
@@ -391,6 +418,21 @@ def save_yaml(yaml_path, data):
|
|
|
391
418
|
change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
392
419
|
|
|
393
420
|
|
|
421
|
+
def save_excel(path, data):
    """
    Function Description:
        Write a pandas DataFrame to an Excel file (without the index column).
    Parameter:
        path: destination file path; validated with check_path_before_create
              and then resolved with os.path.realpath
        data: the pd.DataFrame to save
    Notes:
        Any other data type is logged as an error and the function returns
        without writing anything and without raising.
    Exception Description:
        RuntimeError is raised (chained to the cause) when writing fails.
    """
    check_path_before_create(path)
    path = os.path.realpath(path)
    try:
        # Guard clause: only DataFrame input is supported.
        if not isinstance(data, pd.DataFrame):
            logger.error('unsupported data type.')
            return
        data.to_excel(path, index=False)
    except Exception as e:
        logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
        raise RuntimeError(f"Save excel file {path} failed.") from e
    # Reached only after a successful write.
    change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
434
|
+
|
|
435
|
+
|
|
394
436
|
def move_file(src_path, dst_path):
|
|
395
437
|
check_file_or_directory_path(src_path)
|
|
396
438
|
check_path_before_create(dst_path)
|
|
@@ -403,8 +445,8 @@ def move_file(src_path, dst_path):
|
|
|
403
445
|
|
|
404
446
|
|
|
405
447
|
def save_npy(data, filepath):
|
|
406
|
-
filepath = os.path.realpath(filepath)
|
|
407
448
|
check_path_before_create(filepath)
|
|
449
|
+
filepath = os.path.realpath(filepath)
|
|
408
450
|
try:
|
|
409
451
|
np.save(filepath, data)
|
|
410
452
|
except Exception as e:
|
|
@@ -425,6 +467,7 @@ def save_npy_to_txt(data, dst_file='', align=0):
|
|
|
425
467
|
pad_array = np.zeros((align - data.size % align,))
|
|
426
468
|
data = np.append(data, pad_array)
|
|
427
469
|
check_path_before_create(dst_file)
|
|
470
|
+
dst_file = os.path.realpath(dst_file)
|
|
428
471
|
try:
|
|
429
472
|
np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
|
|
430
473
|
except Exception as e:
|
|
@@ -438,8 +481,8 @@ def save_workbook(workbook, file_path):
|
|
|
438
481
|
workbook: 要保存的工作簿对象
|
|
439
482
|
file_path: 文件保存路径
|
|
440
483
|
"""
|
|
441
|
-
file_path = os.path.realpath(file_path)
|
|
442
484
|
check_path_before_create(file_path)
|
|
485
|
+
file_path = os.path.realpath(file_path)
|
|
443
486
|
try:
|
|
444
487
|
workbook.save(file_path)
|
|
445
488
|
except Exception as e:
|
|
@@ -451,7 +494,7 @@ def save_workbook(workbook, file_path):
|
|
|
451
494
|
def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
452
495
|
def csv_value_is_valid(value: str) -> bool:
|
|
453
496
|
if not isinstance(value, str):
|
|
454
|
-
return True
|
|
497
|
+
return True
|
|
455
498
|
try:
|
|
456
499
|
# -1.00 or +1.00 should be considered as digit numbers
|
|
457
500
|
float(value)
|
|
@@ -459,16 +502,16 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
459
502
|
# otherwise, they will be considered as formula injections
|
|
460
503
|
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
461
504
|
return True
|
|
462
|
-
|
|
505
|
+
|
|
463
506
|
if malicious_check:
|
|
464
507
|
for row in data:
|
|
465
508
|
for cell in row:
|
|
466
509
|
if not csv_value_is_valid(cell):
|
|
467
|
-
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
510
|
+
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
468
511
|
f"to be written into the csv: {filepath}.")
|
|
469
512
|
|
|
470
|
-
file_path = os.path.realpath(filepath)
|
|
471
513
|
check_path_before_create(filepath)
|
|
514
|
+
file_path = os.path.realpath(filepath)
|
|
472
515
|
try:
|
|
473
516
|
with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
|
|
474
517
|
writer = csv.writer(f)
|
|
@@ -479,16 +522,54 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
479
522
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
480
523
|
|
|
481
524
|
|
|
482
|
-
def read_csv(filepath):
|
|
525
|
+
def read_csv(filepath, as_pd=True):
    """
    Function Description:
        Load a csv file after validating the path with
        check_file_or_directory_path.
    Parameter:
        filepath: path of the csv file to read
        as_pd: when True (default) the file is parsed with pd.read_csv;
               otherwise it is opened through FileOpen with the
               'utf-8-sig' encoding and read with csv.reader into a list
               of rows
    Exception Description:
        RuntimeError is raised (chained to the cause) when reading fails.
    """
    check_file_or_directory_path(filepath)
    try:
        if not as_pd:
            with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
                rows = list(csv.reader(f, delimiter=','))
            return rows
        return pd.read_csv(filepath)
    except Exception as e:
        logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
        raise RuntimeError(f"Read csv file {filepath} failed.") from e
|
|
490
538
|
|
|
491
539
|
|
|
540
|
+
def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False):
    """
    Function Description:
        Write a pandas DataFrame to a csv file (without the index column).
    Parameter:
        data: the pd.DataFrame to save
        filepath: destination csv path; validated with check_path_before_create
        mode: file mode forwarded to DataFrame.to_csv
        header: header option forwarded to DataFrame.to_csv
        malicious_check: when True, every cell is inspected before writing;
                         a string cell that cannot be parsed as a float and
                         matches FileCheckConst.CSV_BLACK_LIST is rejected
    Exception Description:
        ValueError when data is not a pd.DataFrame.
        RuntimeError when a malicious value is found or when writing fails.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("The data type of data is not supported. Only support pd.DataFrame.")

    if malicious_check:
        # Compile once instead of once per inspected string.
        black_list = re.compile(FileCheckConst.CSV_BLACK_LIST)

        def csv_value_is_valid(value) -> bool:
            if not isinstance(value, str):
                return True
            try:
                # -1.00 or +1.00 should be considered as digit numbers
                float(value)
            except ValueError:
                # otherwise, they will be considered as formula injections
                return black_list.search(value) is None
            return True

        # itertuples walks the frame row by row (same order as the former
        # iloc[i, j] double loop) without the per-cell indexing overhead.
        for row in data.itertuples(index=False, name=None):
            for cell in row:
                if not csv_value_is_valid(cell):
                    raise RuntimeError(f"Malicious value [{cell}] is not allowed "
                                       f"to be written into the csv: {filepath}.")

    check_path_before_create(filepath)
    # The resolved path is only used in the failure messages below.
    file_path = os.path.realpath(filepath)
    try:
        data.to_csv(filepath, mode=mode, header=header, index=False)
    except Exception as e:
        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
        raise RuntimeError(f"Save csv file {file_path} failed.") from e
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
571
|
+
|
|
572
|
+
|
|
492
573
|
def remove_path(path):
|
|
493
574
|
if not os.path.exists(path):
|
|
494
575
|
return
|
|
@@ -521,3 +602,46 @@ def get_json_contents(file_path):
|
|
|
521
602
|
def get_file_content_bytes(file):
    """
    Function Description:
        Open ``file`` in binary mode through FileOpen and return everything
        that a single read() call yields.
    Parameter:
        file: path of the file to read
    """
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
# 对os.walk设置遍历深度
|
|
608
|
+
def os_walk_for_files(path, depth):
    """
    Function Description:
        Walk ``path`` with os.walk and collect the files found fewer than
        ``depth`` directory levels below it (files directly in ``path`` are at
        level 0).
    Parameter:
        path: directory to start from
        depth: number of directory levels to collect files from; directories
               at level ``depth`` are still validated but neither searched for
               files nor descended into
    Return:
        list of {"file": <file name>, "root": <directory yielded by os.walk>}
    Notes:
        Every visited directory is validated with
        check_file_or_directory_path(root, isdir=True).
    """
    res = []
    for root, dirs, files in os.walk(path, topdown=True):
        check_file_or_directory_path(root, isdir=True)
        # Derive the level from the path relative to the start directory.
        # Counting separators in the raw strings is off by one when ``path``
        # ends with a separator or is the filesystem root.
        rel_root = os.path.relpath(root, path)
        level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
        if level >= depth:
            # With topdown=True, emptying the list in place stops os.walk
            # from descending any further below this directory.
            dirs[:] = []
            continue
        res.extend({"file": file, "root": root} for file in files)
    return res
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def check_crt_valid(pem_path):
    """
    Check the validity of the SSL certificate.

    Load the SSL certificate from the specified path, parse and check its validity period.
    If the certificate is expired or invalid, raise a RuntimeError.

    Parameters:
    pem_path (str): The file path of the SSL certificate.

    Raises:
    RuntimeError: If the SSL certificate is invalid or expired.
    """
    try:
        with FileOpen(pem_path, "r") as f:
            pem_data = f.read()
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
        pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
    except Exception as e:
        logger.error("Failed to parse the SSL certificate. Check the certificate.")
        raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e

    now_utc = datetime.now(tz=timezone.utc)
    if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
        raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")

    # Report success only after the validity window has actually been checked;
    # logging it earlier claimed a pass even for certificates rejected above.
    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from msprobe.core.common.file_utils import load_yaml
|
|
3
18
|
|