mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,28 +12,32 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import collections
|
|
18
17
|
import os
|
|
19
18
|
import re
|
|
20
19
|
import subprocess
|
|
21
20
|
import time
|
|
22
|
-
import
|
|
21
|
+
from collections import defaultdict
|
|
23
22
|
from datetime import datetime, timezone
|
|
23
|
+
from functools import wraps
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
26
28
|
from msprobe.core.common.const import Const, CompareConst
|
|
27
29
|
from msprobe.core.common.log import logger
|
|
28
|
-
|
|
30
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
29
31
|
|
|
30
32
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
31
33
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
32
34
|
|
|
33
35
|
|
|
34
|
-
class
|
|
36
|
+
class MsprobeBaseException(Exception):
|
|
35
37
|
"""
|
|
36
|
-
|
|
38
|
+
Base class for all custom exceptions.
|
|
37
39
|
"""
|
|
40
|
+
# 所有的错误代码
|
|
38
41
|
NONE_ERROR = 0
|
|
39
42
|
INVALID_PATH_ERROR = 1
|
|
40
43
|
OPEN_FILE_ERROR = 2
|
|
@@ -57,10 +60,20 @@ class CompareException(Exception):
|
|
|
57
60
|
INVALID_SUMMARY_MODE = 19
|
|
58
61
|
INVALID_TASK_ERROR = 20
|
|
59
62
|
DETACH_ERROR = 21
|
|
60
|
-
|
|
63
|
+
INVALID_OBJECT_TYPE_ERROR = 22
|
|
64
|
+
INVALID_CHAR_ERROR = 23
|
|
65
|
+
RECURSION_LIMIT_ERROR = 24
|
|
66
|
+
INVALID_ATTRIBUTE_ERROR = 25
|
|
67
|
+
OUTPUT_HOOK_ERROR = 26
|
|
68
|
+
INPUT_HOOK_ERROR = 27
|
|
69
|
+
FUNCTION_CALL_ERROR = 28
|
|
70
|
+
FORWARD_DATA_COLLECTION_ERROR = 29
|
|
71
|
+
BACKWARD_DATA_COLLECTION_ERROR = 30
|
|
72
|
+
INVALID_KEY_ERROR = 31
|
|
73
|
+
MISSING_HEADER_ERROR = 32
|
|
61
74
|
|
|
62
75
|
def __init__(self, code, error_info: str = ""):
|
|
63
|
-
super(
|
|
76
|
+
super(MsprobeBaseException, self).__init__()
|
|
64
77
|
self.code = code
|
|
65
78
|
self.error_info = error_info
|
|
66
79
|
|
|
@@ -68,80 +81,55 @@ class CompareException(Exception):
|
|
|
68
81
|
return self.error_info
|
|
69
82
|
|
|
70
83
|
|
|
71
|
-
class
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def check_mode_valid(mode, scope=None, api_list=None):
|
|
76
|
-
if scope is None:
|
|
77
|
-
scope = []
|
|
78
|
-
if api_list is None:
|
|
79
|
-
api_list = []
|
|
80
|
-
if not isinstance(scope, list):
|
|
81
|
-
raise ValueError("scope param set invalid, it's must be a list.")
|
|
82
|
-
if not isinstance(api_list, list):
|
|
83
|
-
raise ValueError("api_list param set invalid, it's must be a list.")
|
|
84
|
-
mode_check = {
|
|
85
|
-
Const.ALL: lambda: None,
|
|
86
|
-
Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
|
|
87
|
-
Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
|
|
88
|
-
Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
|
|
89
|
-
Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None,
|
|
90
|
-
Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None,
|
|
91
|
-
Const.API_STACK: lambda: None,
|
|
92
|
-
}
|
|
93
|
-
if mode not in Const.DUMP_MODE:
|
|
94
|
-
msg = "Current mode '%s' is not supported. Please use the field in %s" % \
|
|
95
|
-
(mode, Const.DUMP_MODE)
|
|
96
|
-
raise CompareException(CompareException.INVALID_DUMP_MODE, msg)
|
|
97
|
-
|
|
98
|
-
if mode_check.get(mode)() is not None:
|
|
99
|
-
raise mode_check.get(mode)()
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def check_switch_valid(switch):
|
|
103
|
-
if switch not in ["ON", "OFF"]:
|
|
104
|
-
logger.error("Please set switch with 'ON' or 'OFF'.")
|
|
105
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
84
|
+
class CompareException(MsprobeBaseException):
|
|
85
|
+
"""
|
|
86
|
+
Class for Accuracy Compare Exception
|
|
87
|
+
"""
|
|
106
88
|
|
|
89
|
+
def __init__(self, code, error_info: str = ""):
|
|
90
|
+
super(CompareException, self).__init__(code, error_info)
|
|
107
91
|
|
|
108
|
-
def check_dump_mode_valid(dump_mode):
|
|
109
|
-
if not isinstance(dump_mode, list):
|
|
110
|
-
logger.warning("Please set dump_mode as a list.")
|
|
111
|
-
dump_mode = [dump_mode]
|
|
112
|
-
if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
|
|
113
|
-
raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
|
|
114
|
-
if 'input' not in dump_mode and 'output' not in dump_mode:
|
|
115
|
-
dump_mode.extend(['input', 'output'])
|
|
116
|
-
if 'forward' not in dump_mode and 'backward' not in dump_mode:
|
|
117
|
-
dump_mode.extend(['forward', 'backward'])
|
|
118
|
-
if 'all' in dump_mode or set(["forward", "backward", "input", "output"]).issubset(set(dump_mode)):
|
|
119
|
-
return ["forward", "backward", "input", "output"]
|
|
120
|
-
return dump_mode
|
|
121
92
|
|
|
93
|
+
class DumpException(MsprobeBaseException):
|
|
94
|
+
"""
|
|
95
|
+
Class for Dump Exception
|
|
96
|
+
"""
|
|
122
97
|
|
|
123
|
-
def
|
|
124
|
-
|
|
125
|
-
msg = "The summary_mode is not valid"
|
|
126
|
-
raise CompareException(CompareException.INVALID_SUMMARY_MODE, msg)
|
|
98
|
+
def __init__(self, code, error_info: str = ""):
|
|
99
|
+
super(DumpException, self).__init__(code, error_info)
|
|
127
100
|
|
|
101
|
+
def __str__(self):
|
|
102
|
+
return f"Dump Error Code {self.code}: {self.error_info}"
|
|
128
103
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
104
|
+
|
|
105
|
+
def is_json_file(file_path):
|
|
106
|
+
if isinstance(file_path, str) and file_path.lower().endswith('.json'):
|
|
107
|
+
return True
|
|
108
|
+
else:
|
|
109
|
+
return False
|
|
134
110
|
|
|
135
111
|
|
|
136
|
-
def check_compare_param(input_param, output_path,
|
|
137
|
-
if not
|
|
138
|
-
logger.error("Invalid input
|
|
112
|
+
def check_compare_param(input_param, output_path, dump_mode):
|
|
113
|
+
if not isinstance(input_param, dict):
|
|
114
|
+
logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
|
|
139
115
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
116
|
+
if not isinstance(output_path, str):
|
|
117
|
+
logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
|
|
118
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
119
|
+
|
|
120
|
+
def check_json_path(json_path_str):
|
|
121
|
+
json_path = input_param.get(json_path_str)
|
|
122
|
+
check_file_or_directory_path(json_path, False)
|
|
123
|
+
json_type_check = is_json_file(json_path)
|
|
124
|
+
if not json_type_check:
|
|
125
|
+
logger.error(f"Invalid {json_path_str}: {json_path}, please check!")
|
|
126
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
127
|
+
|
|
128
|
+
check_json_path("npu_json_path")
|
|
129
|
+
check_json_path("bench_json_path")
|
|
130
|
+
check_json_path("stack_json_path")
|
|
140
131
|
|
|
141
|
-
|
|
142
|
-
check_file_or_directory_path(input_param.get("bench_json_path"), False)
|
|
143
|
-
check_file_or_directory_path(input_param.get("stack_json_path"), False)
|
|
144
|
-
if not summary_compare and not md5_compare:
|
|
132
|
+
if dump_mode == Const.ALL:
|
|
145
133
|
check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
|
|
146
134
|
check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
|
|
147
135
|
check_file_or_directory_path(output_path, True)
|
|
@@ -152,15 +140,12 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
|
|
|
152
140
|
check_json_file(input_param, npu_json, bench_json, stack_json)
|
|
153
141
|
|
|
154
142
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def is_starts_with(string, prefix_list):
|
|
163
|
-
return any(string.startswith(prefix) for prefix in prefix_list)
|
|
143
|
+
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
|
|
144
|
+
arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
|
|
145
|
+
for arg in arg_list:
|
|
146
|
+
if not isinstance(arg, bool):
|
|
147
|
+
logger.error(f"Invalid input parameter, {arg} which should be only bool type.")
|
|
148
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
164
149
|
|
|
165
150
|
|
|
166
151
|
def _check_json(json_file_handle, file_name):
|
|
@@ -198,28 +183,6 @@ def check_regex_prefix_format_valid(prefix):
|
|
|
198
183
|
raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
199
184
|
|
|
200
185
|
|
|
201
|
-
def get_dump_data_path(dump_dir):
|
|
202
|
-
"""
|
|
203
|
-
Function Description:
|
|
204
|
-
traverse directories and obtain the absolute path of dump data
|
|
205
|
-
Parameter:
|
|
206
|
-
dump_dir: dump data directory
|
|
207
|
-
Return Value:
|
|
208
|
-
dump data path,file is exist or file is not exist
|
|
209
|
-
"""
|
|
210
|
-
dump_data_path = None
|
|
211
|
-
file_is_exist = False
|
|
212
|
-
|
|
213
|
-
check_file_or_directory_path(dump_dir, True)
|
|
214
|
-
for dir_path, _, files in os.walk(dump_dir):
|
|
215
|
-
if len(files) != 0:
|
|
216
|
-
dump_data_path = dir_path
|
|
217
|
-
file_is_exist = True
|
|
218
|
-
break
|
|
219
|
-
dump_data_path = dir_path
|
|
220
|
-
return dump_data_path, file_is_exist
|
|
221
|
-
|
|
222
|
-
|
|
223
186
|
def execute_command(cmd):
|
|
224
187
|
"""
|
|
225
188
|
Function Description:
|
|
@@ -235,28 +198,12 @@ def execute_command(cmd):
|
|
|
235
198
|
line = process.stdout.readline()
|
|
236
199
|
line = line.strip()
|
|
237
200
|
if line:
|
|
238
|
-
|
|
201
|
+
logger.info(line)
|
|
239
202
|
if process.returncode != 0:
|
|
240
203
|
logger.error('Failed to execute command:%s' % " ".join(cmd))
|
|
241
204
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
242
205
|
|
|
243
206
|
|
|
244
|
-
def parse_value_by_comma(value):
|
|
245
|
-
"""
|
|
246
|
-
parse value by comma, like '1,2,4,8'
|
|
247
|
-
"""
|
|
248
|
-
value_list = []
|
|
249
|
-
value_str_list = value.split(Const.COMMA)
|
|
250
|
-
for value_str in value_str_list:
|
|
251
|
-
value_str = value_str.strip()
|
|
252
|
-
if value_str.isdigit() or value_str == '-1':
|
|
253
|
-
value_list.append(int(value_str))
|
|
254
|
-
else:
|
|
255
|
-
logger.error("please check your input shape.")
|
|
256
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
257
|
-
return value_list
|
|
258
|
-
|
|
259
|
-
|
|
260
207
|
def add_time_as_suffix(name):
|
|
261
208
|
return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
262
209
|
|
|
@@ -265,6 +212,10 @@ def add_time_with_xlsx(name):
|
|
|
265
212
|
return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
266
213
|
|
|
267
214
|
|
|
215
|
+
def add_time_with_yaml(name):
|
|
216
|
+
return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
217
|
+
|
|
218
|
+
|
|
268
219
|
def get_time():
|
|
269
220
|
return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
|
|
270
221
|
|
|
@@ -273,61 +224,6 @@ def format_value(value):
|
|
|
273
224
|
return float('{:.12f}'.format(value))
|
|
274
225
|
|
|
275
226
|
|
|
276
|
-
def check_seed_all(seed, mode):
|
|
277
|
-
if isinstance(seed, int):
|
|
278
|
-
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
279
|
-
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
280
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
281
|
-
else:
|
|
282
|
-
logger.error(f"Seed must be integer.")
|
|
283
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
284
|
-
if not isinstance(mode, bool):
|
|
285
|
-
logger.error(f"seed_all mode must be bool.")
|
|
286
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def get_process_rank(model):
|
|
290
|
-
logger.info("Rank id is not provided. Trying to get the rank id of the model.")
|
|
291
|
-
try:
|
|
292
|
-
local_device = next(model.parameters()).device
|
|
293
|
-
except StopIteration:
|
|
294
|
-
logger.warning('There is no parameter in the model. Fail to get rank id.')
|
|
295
|
-
return 0, False
|
|
296
|
-
if local_device.type == 'cpu':
|
|
297
|
-
logger.warning("Warning: the debugger is unable to get the rank id. "
|
|
298
|
-
"This may cause the dumpped data to be corrupted in the "
|
|
299
|
-
"case of distributed training. (You may ignore this if you are using only one card.) "
|
|
300
|
-
"Transfer the model to npu or gpu before register_hook() to avoid this warning.")
|
|
301
|
-
return 0, False
|
|
302
|
-
else:
|
|
303
|
-
return local_device.index, True
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
|
|
307
|
-
template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
|
|
308
|
-
pkl_dir = os.path.dirname(pkl_file_path)
|
|
309
|
-
compare_script_path = os.path.join(pkl_dir, "compare_data.py")
|
|
310
|
-
is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
|
|
311
|
-
|
|
312
|
-
try:
|
|
313
|
-
with FileOpen(template_path, 'r') as ftemp, \
|
|
314
|
-
os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
|
|
315
|
-
code_temp = ftemp.read()
|
|
316
|
-
fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
|
|
317
|
-
except OSError:
|
|
318
|
-
logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
|
|
319
|
-
|
|
320
|
-
logger.info(f"Generate compare script successfully which is {compare_script_path}.")
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
def check_inplace_op(prefix):
|
|
324
|
-
if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
|
|
325
|
-
return False
|
|
326
|
-
match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
|
|
327
|
-
op_name = match_op[0] if match_op else None
|
|
328
|
-
return op_name in Const.INPLACE_LIST
|
|
329
|
-
|
|
330
|
-
|
|
331
227
|
def md5_find(data):
|
|
332
228
|
for key_op in data:
|
|
333
229
|
for api_info in data[key_op]:
|
|
@@ -335,46 +231,89 @@ def md5_find(data):
|
|
|
335
231
|
for data_detail in data[key_op][api_info]:
|
|
336
232
|
if data_detail and 'md5' in data_detail:
|
|
337
233
|
return True
|
|
338
|
-
elif 'md5' in data[key_op][api_info]:
|
|
234
|
+
elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
|
|
339
235
|
return True
|
|
340
236
|
return False
|
|
341
237
|
|
|
342
238
|
|
|
343
|
-
def
|
|
239
|
+
def detect_framework_by_dump_json(file_path):
|
|
240
|
+
pattern_ms = r'"type":\s*"mindspore'
|
|
241
|
+
pattern_pt = r'"type":\s*"torch'
|
|
242
|
+
with FileOpen(file_path, 'r') as file:
|
|
243
|
+
for line in file:
|
|
244
|
+
if re.search(pattern_ms, line):
|
|
245
|
+
return Const.MS_FRAMEWORK
|
|
246
|
+
if re.search(pattern_pt, line):
|
|
247
|
+
return Const.PT_FRAMEWORK
|
|
248
|
+
logger.error(f"{file_path} must be based on the MindSpore or PyTorch framework.")
|
|
249
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
253
|
+
if not dump_json_path:
|
|
254
|
+
logger.error("The path is empty. Please enter a valid path.")
|
|
255
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
256
|
+
directory = os.path.dirname(dump_json_path)
|
|
257
|
+
check_file_or_directory_path(directory, True)
|
|
258
|
+
stack_json = os.path.join(directory, "stack.json")
|
|
259
|
+
construct_json = os.path.join(directory, "construct.json")
|
|
260
|
+
|
|
261
|
+
stack = load_json(stack_json)
|
|
262
|
+
construct = load_json(construct_json)
|
|
263
|
+
return stack, construct
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def set_dump_path(input_param):
|
|
344
267
|
npu_path = input_param.get("npu_json_path", None)
|
|
345
268
|
bench_path = input_param.get("bench_json_path", None)
|
|
346
|
-
|
|
347
|
-
|
|
269
|
+
npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
|
|
270
|
+
bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
|
|
271
|
+
if not npu_path_valid or not bench_path_valid:
|
|
272
|
+
logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
|
|
348
273
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
274
|
+
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
275
|
+
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def get_dump_mode(input_param):
|
|
279
|
+
npu_path = input_param.get("npu_json_path", None)
|
|
280
|
+
bench_path = input_param.get("bench_json_path", None)
|
|
281
|
+
npu_json_data = load_json(npu_path)
|
|
282
|
+
bench_json_data = load_json(bench_path)
|
|
283
|
+
|
|
284
|
+
npu_task = npu_json_data.get('task', None)
|
|
285
|
+
bench_task = bench_json_data.get('task', None)
|
|
286
|
+
|
|
287
|
+
if not npu_task or not bench_task:
|
|
288
|
+
logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
|
|
289
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
290
|
+
|
|
291
|
+
if npu_task != bench_task:
|
|
354
292
|
logger.error(f"Please check the dump task is consistent.")
|
|
355
293
|
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
294
|
+
|
|
295
|
+
if npu_task == Const.TENSOR:
|
|
296
|
+
return Const.ALL
|
|
297
|
+
|
|
298
|
+
if npu_task == Const.STATISTICS:
|
|
299
|
+
npu_md5_compare = md5_find(npu_json_data['data'])
|
|
300
|
+
bench_md5_compare = md5_find(bench_json_data['data'])
|
|
301
|
+
if npu_md5_compare == bench_md5_compare:
|
|
302
|
+
return Const.MD5 if npu_md5_compare else Const.SUMMARY
|
|
363
303
|
else:
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
368
|
-
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
369
|
-
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
370
|
-
return summary_compare, md5_compare
|
|
304
|
+
logger.error(f"Please check the dump task is consistent, "
|
|
305
|
+
f"dump mode of npu and bench should both be statistics or md5.")
|
|
306
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
371
307
|
|
|
308
|
+
logger.error(f"Compare applies only to task is tensor or statistics")
|
|
309
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
372
310
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
311
|
+
|
|
312
|
+
def get_header_index(header_name, dump_mode):
|
|
313
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE.get(dump_mode)
|
|
314
|
+
if not header:
|
|
315
|
+
logger.error(f"{dump_mode} not in {CompareConst.HEAD_OF_COMPARE_MODE}")
|
|
316
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
378
317
|
if header_name not in header:
|
|
379
318
|
logger.error(f"{header_name} not in data name")
|
|
380
319
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -382,4 +321,164 @@ def get_header_index(header_name, summary_compare=False):
|
|
|
382
321
|
|
|
383
322
|
|
|
384
323
|
def convert_tuple(data):
|
|
385
|
-
return data if isinstance(data, tuple) else (data,
|
|
324
|
+
return data if isinstance(data, tuple) else (data,)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def check_op_str_pattern_valid(string, op_name=None, stack=False):
|
|
328
|
+
if isinstance(string, str) and is_invalid_pattern(string):
|
|
329
|
+
if stack:
|
|
330
|
+
message = f"stack info of {op_name} contains special characters, please check!"
|
|
331
|
+
elif not op_name:
|
|
332
|
+
message = f"api name contains special characters, please check!"
|
|
333
|
+
else:
|
|
334
|
+
message = f"data info of {op_name} contains special characters, please check!"
|
|
335
|
+
logger.error(message)
|
|
336
|
+
raise CompareException(CompareException.INVALID_CHAR_ERROR)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def is_invalid_pattern(string):
|
|
340
|
+
pattern = Const.STRING_BLACKLIST
|
|
341
|
+
return re.search(pattern, string)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def is_int(x):
|
|
345
|
+
return isinstance(x, int) and not isinstance(x, bool)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def print_tools_ends_info():
|
|
349
|
+
total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
350
|
+
logger.info('*' * total_len)
|
|
351
|
+
logger.info(f"*{Const.TOOL_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
|
|
352
|
+
logger.info('*' * total_len)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def get_step_or_rank_from_string(step_or_rank, obj):
|
|
356
|
+
splited = step_or_rank.split(Const.HYPHEN)
|
|
357
|
+
if len(splited) == 2:
|
|
358
|
+
try:
|
|
359
|
+
borderlines = int(splited[0]), int(splited[1])
|
|
360
|
+
except (ValueError, IndexError) as e:
|
|
361
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
362
|
+
"The hyphen(-) must start and end with decimal numbers.") from e
|
|
363
|
+
else:
|
|
364
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
365
|
+
f'The string parameter for {obj} only supports formats like "3-5". '
|
|
366
|
+
f'Now string parameter for {obj} is "{step_or_rank}".')
|
|
367
|
+
if all(Const.STEP_RANK_MINIMUM_VALUE <= b <= Const.STEP_RANK_MAXIMUM_VALUE for b in borderlines):
|
|
368
|
+
if borderlines[0] <= borderlines[1]:
|
|
369
|
+
continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
|
|
370
|
+
else:
|
|
371
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
372
|
+
f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be '
|
|
373
|
+
f'greater than the right boundary ({borderlines[1]}).')
|
|
374
|
+
else:
|
|
375
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
376
|
+
f"The boundaries must fall within the range of "
|
|
377
|
+
f"[{Const.STEP_RANK_MINIMUM_VALUE}, {Const.STEP_RANK_MAXIMUM_VALUE}].")
|
|
378
|
+
return continual_step_or_rank
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def get_real_step_or_rank(step_or_rank_input, obj):
|
|
382
|
+
if obj not in [Const.STEP, Const.RANK]:
|
|
383
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
384
|
+
f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
|
|
385
|
+
if step_or_rank_input is None:
|
|
386
|
+
return []
|
|
387
|
+
if not isinstance(step_or_rank_input, list):
|
|
388
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
|
|
389
|
+
if len(step_or_rank_input) > Const.STEP_RANK_MAXIMUM_VALUE:
|
|
390
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
391
|
+
f"{obj} is invalid, its length cannot exceed {Const.STEP_RANK_MAXIMUM_VALUE}")
|
|
392
|
+
|
|
393
|
+
real_step_or_rank = []
|
|
394
|
+
for element in step_or_rank_input:
|
|
395
|
+
if not is_int(element) and not isinstance(element, str):
|
|
396
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
397
|
+
f"{obj} element {element} must be an integer or string.")
|
|
398
|
+
if isinstance(element, int) and element < 0:
|
|
399
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
400
|
+
f"Each element of {obj} must be non-negative, currently it is {element}.")
|
|
401
|
+
if isinstance(element, int) and Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
|
|
402
|
+
real_step_or_rank.append(element)
|
|
403
|
+
elif isinstance(element, str) and Const.HYPHEN in element:
|
|
404
|
+
continual_step_or_rank = get_step_or_rank_from_string(element, obj)
|
|
405
|
+
real_step_or_rank.extend(continual_step_or_rank)
|
|
406
|
+
real_step_or_rank = list(set(real_step_or_rank))
|
|
407
|
+
real_step_or_rank.sort()
|
|
408
|
+
return real_step_or_rank
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def check_seed_all(seed, mode):
|
|
412
|
+
if is_int(seed):
|
|
413
|
+
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
414
|
+
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
415
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
416
|
+
else:
|
|
417
|
+
logger.error("Seed must be integer.")
|
|
418
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
419
|
+
if not isinstance(mode, bool):
|
|
420
|
+
logger.error("seed_all mode must be bool.")
|
|
421
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def safe_get_value(container, index, container_name, key=None):
|
|
425
|
+
try:
|
|
426
|
+
# 处理字典情况
|
|
427
|
+
if isinstance(container, dict):
|
|
428
|
+
return container.get(key)[index]
|
|
429
|
+
# 处理列表、元组、numpy情况
|
|
430
|
+
elif isinstance(container, (list, tuple, np.ndarray)):
|
|
431
|
+
return container[index]
|
|
432
|
+
else:
|
|
433
|
+
err_msg = f"Unsupported container type for '{container_name}': {type(container)}"
|
|
434
|
+
logger.error(err_msg)
|
|
435
|
+
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR)
|
|
436
|
+
except IndexError as e:
|
|
437
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
438
|
+
f"{container_name} is {container}\n" \
|
|
439
|
+
f"index is {index}"
|
|
440
|
+
logger.error(err_msg)
|
|
441
|
+
raise MsprobeBaseException(MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
442
|
+
except TypeError as e:
|
|
443
|
+
err_msg = "wrong type, please check!\n" \
|
|
444
|
+
f"{container_name} is {container}\n" \
|
|
445
|
+
f"index is {index}\n" \
|
|
446
|
+
f"key is {key}"
|
|
447
|
+
logger.error(err_msg)
|
|
448
|
+
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
# 记录工具函数递归的深度
|
|
452
|
+
recursion_depth = defaultdict(int)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
# 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。
|
|
456
|
+
def recursion_depth_decorator(func_info):
|
|
457
|
+
def decorator(func):
|
|
458
|
+
@wraps(func)
|
|
459
|
+
def wrapper(*args, **kwargs):
|
|
460
|
+
func_id = id(func)
|
|
461
|
+
recursion_depth[func_id] += 1
|
|
462
|
+
if recursion_depth[func_id] > Const.MAX_DEPTH:
|
|
463
|
+
msg = f"call {func_info} exceeds the recursion limit."
|
|
464
|
+
logger.error_log_with_exp(
|
|
465
|
+
msg,
|
|
466
|
+
MsprobeException(
|
|
467
|
+
MsprobeException.RECURSION_LIMIT_ERROR, msg
|
|
468
|
+
),
|
|
469
|
+
)
|
|
470
|
+
try:
|
|
471
|
+
result = func(*args, **kwargs)
|
|
472
|
+
finally:
|
|
473
|
+
recursion_depth[func_id] -= 1
|
|
474
|
+
return result
|
|
475
|
+
|
|
476
|
+
return wrapper
|
|
477
|
+
|
|
478
|
+
return decorator
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def check_str_param(param):
|
|
482
|
+
if not re.match(Const.REGEX_PREFIX_PATTERN, param):
|
|
483
|
+
logger.error('The parameter {} contains special characters.'.format(param))
|
|
484
|
+
raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
|