mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -22,19 +22,21 @@ import time
|
|
|
22
22
|
import json
|
|
23
23
|
from datetime import datetime, timezone
|
|
24
24
|
|
|
25
|
-
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
|
|
25
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
26
26
|
from msprobe.core.common.const import Const, CompareConst
|
|
27
27
|
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
31
32
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
32
33
|
|
|
33
34
|
|
|
34
|
-
class
|
|
35
|
+
class MsprobeBaseException(Exception):
|
|
35
36
|
"""
|
|
36
|
-
|
|
37
|
+
Base class for all custom exceptions.
|
|
37
38
|
"""
|
|
39
|
+
# 所有的错误代码
|
|
38
40
|
NONE_ERROR = 0
|
|
39
41
|
INVALID_PATH_ERROR = 1
|
|
40
42
|
OPEN_FILE_ERROR = 2
|
|
@@ -57,10 +59,18 @@ class CompareException(Exception):
|
|
|
57
59
|
INVALID_SUMMARY_MODE = 19
|
|
58
60
|
INVALID_TASK_ERROR = 20
|
|
59
61
|
DETACH_ERROR = 21
|
|
60
|
-
|
|
62
|
+
INVALID_OBJECT_TYPE_ERROR = 22
|
|
63
|
+
INVALID_CHAR_ERROR = 23
|
|
64
|
+
RECURSION_LIMIT_ERROR = 24
|
|
65
|
+
INVALID_ATTRIBUTE_ERROR = 25
|
|
66
|
+
OUTPUT_HOOK_ERROR = 26
|
|
67
|
+
INPUT_HOOK_ERROR = 27
|
|
68
|
+
FUNCTION_CALL_ERROR = 28
|
|
69
|
+
FORWARD_DATA_COLLECTION_ERROR = 29
|
|
70
|
+
BACKWARD_DATA_COLLECTION_ERROR = 30
|
|
61
71
|
|
|
62
72
|
def __init__(self, code, error_info: str = ""):
|
|
63
|
-
super(
|
|
73
|
+
super(MsprobeBaseException, self).__init__()
|
|
64
74
|
self.code = code
|
|
65
75
|
self.error_info = error_info
|
|
66
76
|
|
|
@@ -68,74 +78,33 @@ class CompareException(Exception):
|
|
|
68
78
|
return self.error_info
|
|
69
79
|
|
|
70
80
|
|
|
71
|
-
class
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def check_mode_valid(mode, scope=None, api_list=None):
|
|
76
|
-
if scope is None:
|
|
77
|
-
scope = []
|
|
78
|
-
if api_list is None:
|
|
79
|
-
api_list = []
|
|
80
|
-
if not isinstance(scope, list):
|
|
81
|
-
raise ValueError("scope param set invalid, it's must be a list.")
|
|
82
|
-
if not isinstance(api_list, list):
|
|
83
|
-
raise ValueError("api_list param set invalid, it's must be a list.")
|
|
84
|
-
mode_check = {
|
|
85
|
-
Const.ALL: lambda: None,
|
|
86
|
-
Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
|
|
87
|
-
Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
|
|
88
|
-
Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
|
|
89
|
-
Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None,
|
|
90
|
-
Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None,
|
|
91
|
-
Const.API_STACK: lambda: None,
|
|
92
|
-
}
|
|
93
|
-
if mode not in Const.DUMP_MODE:
|
|
94
|
-
msg = "Current mode '%s' is not supported. Please use the field in %s" % \
|
|
95
|
-
(mode, Const.DUMP_MODE)
|
|
96
|
-
raise CompareException(CompareException.INVALID_DUMP_MODE, msg)
|
|
97
|
-
|
|
98
|
-
if mode_check.get(mode)() is not None:
|
|
99
|
-
raise mode_check.get(mode)()
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def check_switch_valid(switch):
|
|
103
|
-
if switch not in ["ON", "OFF"]:
|
|
104
|
-
logger.error("Please set switch with 'ON' or 'OFF'.")
|
|
105
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
106
|
-
|
|
81
|
+
class CompareException(MsprobeBaseException):
|
|
82
|
+
"""
|
|
83
|
+
Class for Accuracy Compare Exception
|
|
84
|
+
"""
|
|
107
85
|
|
|
108
|
-
def
|
|
109
|
-
|
|
110
|
-
logger.warning("Please set dump_mode as a list.")
|
|
111
|
-
dump_mode = [dump_mode]
|
|
112
|
-
if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
|
|
113
|
-
raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
|
|
114
|
-
if 'input' not in dump_mode and 'output' not in dump_mode:
|
|
115
|
-
dump_mode.extend(['input', 'output'])
|
|
116
|
-
if 'forward' not in dump_mode and 'backward' not in dump_mode:
|
|
117
|
-
dump_mode.extend(['forward', 'backward'])
|
|
118
|
-
if 'all' in dump_mode or set(["forward", "backward", "input", "output"]).issubset(set(dump_mode)):
|
|
119
|
-
return ["forward", "backward", "input", "output"]
|
|
120
|
-
return dump_mode
|
|
86
|
+
def __init__(self, code, error_info: str = ""):
|
|
87
|
+
super(CompareException, self).__init__(code, error_info)
|
|
121
88
|
|
|
122
89
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
90
|
+
class DumpException(MsprobeBaseException):
|
|
91
|
+
"""
|
|
92
|
+
Class for Dump Exception
|
|
93
|
+
"""
|
|
127
94
|
|
|
95
|
+
def __init__(self, code, error_info: str = ""):
|
|
96
|
+
super(DumpException, self).__init__(code, error_info)
|
|
128
97
|
|
|
129
|
-
def
|
|
130
|
-
|
|
131
|
-
logger.error("Params summary_only only support True or False.")
|
|
132
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
133
|
-
return summary_only
|
|
98
|
+
def __str__(self):
|
|
99
|
+
return f"Dump Error Code {self.code}: {self.error_info}"
|
|
134
100
|
|
|
135
101
|
|
|
136
102
|
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
|
|
137
|
-
if not
|
|
138
|
-
logger.error("Invalid input
|
|
103
|
+
if not isinstance(input_param, dict):
|
|
104
|
+
logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
|
|
105
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
106
|
+
if not isinstance(output_path, str):
|
|
107
|
+
logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
|
|
139
108
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
140
109
|
|
|
141
110
|
check_file_or_directory_path(input_param.get("npu_json_path"), False)
|
|
@@ -152,15 +121,12 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
|
|
|
152
121
|
check_json_file(input_param, npu_json, bench_json, stack_json)
|
|
153
122
|
|
|
154
123
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def is_starts_with(string, prefix_list):
|
|
163
|
-
return any(string.startswith(prefix) for prefix in prefix_list)
|
|
124
|
+
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
|
|
125
|
+
arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
|
|
126
|
+
for arg in arg_list:
|
|
127
|
+
if not isinstance(arg, bool):
|
|
128
|
+
logger.error(f"Invalid input parameter, {arg} which should be only bool type.")
|
|
129
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
164
130
|
|
|
165
131
|
|
|
166
132
|
def _check_json(json_file_handle, file_name):
|
|
@@ -198,28 +164,6 @@ def check_regex_prefix_format_valid(prefix):
|
|
|
198
164
|
raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
199
165
|
|
|
200
166
|
|
|
201
|
-
def get_dump_data_path(dump_dir):
|
|
202
|
-
"""
|
|
203
|
-
Function Description:
|
|
204
|
-
traverse directories and obtain the absolute path of dump data
|
|
205
|
-
Parameter:
|
|
206
|
-
dump_dir: dump data directory
|
|
207
|
-
Return Value:
|
|
208
|
-
dump data path,file is exist or file is not exist
|
|
209
|
-
"""
|
|
210
|
-
dump_data_path = None
|
|
211
|
-
file_is_exist = False
|
|
212
|
-
|
|
213
|
-
check_file_or_directory_path(dump_dir, True)
|
|
214
|
-
for dir_path, _, files in os.walk(dump_dir):
|
|
215
|
-
if len(files) != 0:
|
|
216
|
-
dump_data_path = dir_path
|
|
217
|
-
file_is_exist = True
|
|
218
|
-
break
|
|
219
|
-
dump_data_path = dir_path
|
|
220
|
-
return dump_data_path, file_is_exist
|
|
221
|
-
|
|
222
|
-
|
|
223
167
|
def execute_command(cmd):
|
|
224
168
|
"""
|
|
225
169
|
Function Description:
|
|
@@ -241,22 +185,6 @@ def execute_command(cmd):
|
|
|
241
185
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
242
186
|
|
|
243
187
|
|
|
244
|
-
def parse_value_by_comma(value):
|
|
245
|
-
"""
|
|
246
|
-
parse value by comma, like '1,2,4,8'
|
|
247
|
-
"""
|
|
248
|
-
value_list = []
|
|
249
|
-
value_str_list = value.split(Const.COMMA)
|
|
250
|
-
for value_str in value_str_list:
|
|
251
|
-
value_str = value_str.strip()
|
|
252
|
-
if value_str.isdigit() or value_str == '-1':
|
|
253
|
-
value_list.append(int(value_str))
|
|
254
|
-
else:
|
|
255
|
-
logger.error("please check your input shape.")
|
|
256
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
257
|
-
return value_list
|
|
258
|
-
|
|
259
|
-
|
|
260
188
|
def add_time_as_suffix(name):
|
|
261
189
|
return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
262
190
|
|
|
@@ -265,6 +193,10 @@ def add_time_with_xlsx(name):
|
|
|
265
193
|
return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
266
194
|
|
|
267
195
|
|
|
196
|
+
def add_time_with_yaml(name):
|
|
197
|
+
return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
198
|
+
|
|
199
|
+
|
|
268
200
|
def get_time():
|
|
269
201
|
return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
|
|
270
202
|
|
|
@@ -273,61 +205,6 @@ def format_value(value):
|
|
|
273
205
|
return float('{:.12f}'.format(value))
|
|
274
206
|
|
|
275
207
|
|
|
276
|
-
def check_seed_all(seed, mode):
|
|
277
|
-
if isinstance(seed, int):
|
|
278
|
-
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
279
|
-
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
280
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
281
|
-
else:
|
|
282
|
-
logger.error(f"Seed must be integer.")
|
|
283
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
284
|
-
if not isinstance(mode, bool):
|
|
285
|
-
logger.error(f"seed_all mode must be bool.")
|
|
286
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def get_process_rank(model):
|
|
290
|
-
logger.info("Rank id is not provided. Trying to get the rank id of the model.")
|
|
291
|
-
try:
|
|
292
|
-
local_device = next(model.parameters()).device
|
|
293
|
-
except StopIteration:
|
|
294
|
-
logger.warning('There is no parameter in the model. Fail to get rank id.')
|
|
295
|
-
return 0, False
|
|
296
|
-
if local_device.type == 'cpu':
|
|
297
|
-
logger.warning("Warning: the debugger is unable to get the rank id. "
|
|
298
|
-
"This may cause the dumpped data to be corrupted in the "
|
|
299
|
-
"case of distributed training. (You may ignore this if you are using only one card.) "
|
|
300
|
-
"Transfer the model to npu or gpu before register_hook() to avoid this warning.")
|
|
301
|
-
return 0, False
|
|
302
|
-
else:
|
|
303
|
-
return local_device.index, True
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
|
|
307
|
-
template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
|
|
308
|
-
pkl_dir = os.path.dirname(pkl_file_path)
|
|
309
|
-
compare_script_path = os.path.join(pkl_dir, "compare_data.py")
|
|
310
|
-
is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
|
|
311
|
-
|
|
312
|
-
try:
|
|
313
|
-
with FileOpen(template_path, 'r') as ftemp, \
|
|
314
|
-
os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
|
|
315
|
-
code_temp = ftemp.read()
|
|
316
|
-
fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
|
|
317
|
-
except OSError:
|
|
318
|
-
logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
|
|
319
|
-
|
|
320
|
-
logger.info(f"Generate compare script successfully which is {compare_script_path}.")
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
def check_inplace_op(prefix):
|
|
324
|
-
if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
|
|
325
|
-
return False
|
|
326
|
-
match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
|
|
327
|
-
op_name = match_op[0] if match_op else None
|
|
328
|
-
return op_name in Const.INPLACE_LIST
|
|
329
|
-
|
|
330
|
-
|
|
331
208
|
def md5_find(data):
|
|
332
209
|
for key_op in data:
|
|
333
210
|
for api_info in data[key_op]:
|
|
@@ -340,6 +217,29 @@ def md5_find(data):
|
|
|
340
217
|
return False
|
|
341
218
|
|
|
342
219
|
|
|
220
|
+
def struct_json_get(input_param, framework):
|
|
221
|
+
if framework == Const.PT_FRAMEWORK:
|
|
222
|
+
prefix = "bench"
|
|
223
|
+
elif framework == Const.MS_FRAMEWORK:
|
|
224
|
+
prefix = "npu"
|
|
225
|
+
else:
|
|
226
|
+
logger.error("Error framework found.")
|
|
227
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
228
|
+
|
|
229
|
+
frame_json_path = input_param.get(f"{prefix}_json_path", None)
|
|
230
|
+
if not frame_json_path:
|
|
231
|
+
logger.error(f"Please check the json path is valid.")
|
|
232
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
233
|
+
directory = os.path.dirname(frame_json_path)
|
|
234
|
+
check_file_or_directory_path(directory, True)
|
|
235
|
+
stack_json = os.path.join(directory, "stack.json")
|
|
236
|
+
construct_json = os.path.join(directory, "construct.json")
|
|
237
|
+
|
|
238
|
+
stack = load_json(stack_json)
|
|
239
|
+
construct = load_json(construct_json)
|
|
240
|
+
return stack, construct
|
|
241
|
+
|
|
242
|
+
|
|
343
243
|
def task_dumppath_get(input_param):
|
|
344
244
|
npu_path = input_param.get("npu_json_path", None)
|
|
345
245
|
bench_path = input_param.get("bench_json_path", None)
|
|
@@ -383,3 +283,89 @@ def get_header_index(header_name, summary_compare=False):
|
|
|
383
283
|
|
|
384
284
|
def convert_tuple(data):
|
|
385
285
|
return data if isinstance(data, tuple) else (data, )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def check_op_str_pattern_valid(string, op_name=None, stack=False):
|
|
289
|
+
if isinstance(string, str) and is_invalid_pattern(string):
|
|
290
|
+
if stack:
|
|
291
|
+
message = f"stack info of {op_name} contains special characters, please check!"
|
|
292
|
+
elif not op_name:
|
|
293
|
+
message = f"api name contains special characters, please check!"
|
|
294
|
+
else:
|
|
295
|
+
message = f"data info of {op_name} contains special characters, please check!"
|
|
296
|
+
logger.error(message)
|
|
297
|
+
raise CompareException(CompareException.INVALID_CHAR_ERROR)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def is_invalid_pattern(string):
|
|
301
|
+
pattern = Const.STRING_BLACKLIST
|
|
302
|
+
return re.search(pattern, string)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def print_tools_ends_info():
|
|
306
|
+
total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
307
|
+
logger.info('*' * total_len)
|
|
308
|
+
logger.info(f"*{Const.TOOL_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
|
|
309
|
+
logger.info('*' * total_len)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def get_step_or_rank_from_string(step_or_rank, obj):
|
|
313
|
+
splited = step_or_rank.split(Const.HYPHEN)
|
|
314
|
+
if len(splited) == 2:
|
|
315
|
+
try:
|
|
316
|
+
borderlines = int(splited[0]), int(splited[1])
|
|
317
|
+
except (ValueError, IndexError) as e:
|
|
318
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
319
|
+
"The hyphen(-) must start and end with decimal numbers.") from e
|
|
320
|
+
else:
|
|
321
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
322
|
+
f'The string parameter for {obj} only supports formats like "3-5". Now string parameter for {obj} is "{step_or_rank}".')
|
|
323
|
+
if all(Const.STEP_RANK_MAXIMUM_RANGE[0] <= b <= Const.STEP_RANK_MAXIMUM_RANGE[1] for b in borderlines):
|
|
324
|
+
if borderlines[0] <= borderlines[1]:
|
|
325
|
+
continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
|
|
326
|
+
else:
|
|
327
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
328
|
+
f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be greater than the right boundary ({borderlines[1]}).')
|
|
329
|
+
else:
|
|
330
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
331
|
+
f"The boundaries must fall within the range of [{Const.STEP_RANK_MAXIMUM_RANGE[0]}, {Const.STEP_RANK_MAXIMUM_RANGE[1]}].")
|
|
332
|
+
return continual_step_or_rank
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def get_real_step_or_rank(step_or_rank_input, obj):
|
|
336
|
+
if obj not in [Const.STEP, Const.RANK]:
|
|
337
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
338
|
+
f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
|
|
339
|
+
if step_or_rank_input is None:
|
|
340
|
+
return []
|
|
341
|
+
if not isinstance(step_or_rank_input, list):
|
|
342
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
|
|
343
|
+
real_step_or_rank = []
|
|
344
|
+
for element in step_or_rank_input:
|
|
345
|
+
if not isinstance(element, (int, str)):
|
|
346
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
347
|
+
f"{obj} element {element} must be an integer or string.")
|
|
348
|
+
if isinstance(element, int) and element < 0:
|
|
349
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
350
|
+
f"Each element of {obj} must be non-negative, currently it is {element}.")
|
|
351
|
+
if isinstance(element, int) and Const.STEP_RANK_MAXIMUM_RANGE[0] <= element <= Const.STEP_RANK_MAXIMUM_RANGE[1]:
|
|
352
|
+
real_step_or_rank.append(element)
|
|
353
|
+
elif isinstance(element, str) and Const.HYPHEN in element:
|
|
354
|
+
continual_step_or_rank = get_step_or_rank_from_string(element, obj)
|
|
355
|
+
real_step_or_rank.extend(continual_step_or_rank)
|
|
356
|
+
real_step_or_rank = list(set(real_step_or_rank))
|
|
357
|
+
real_step_or_rank.sort()
|
|
358
|
+
return real_step_or_rank
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def check_seed_all(seed, mode):
|
|
362
|
+
if isinstance(seed, int):
|
|
363
|
+
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
364
|
+
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
365
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
366
|
+
else:
|
|
367
|
+
logger.error("Seed must be integer.")
|
|
368
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
369
|
+
if not isinstance(mode, bool):
|
|
370
|
+
logger.error("seed_all mode must be bool.")
|
|
371
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
msprobe/core/common_config.py
CHANGED
|
@@ -2,18 +2,17 @@ from msprobe.core.common.const import Const, FileCheckConst
|
|
|
2
2
|
from msprobe.core.common.log import logger
|
|
3
3
|
from msprobe.core.common.exceptions import MsprobeException
|
|
4
4
|
from msprobe.core.common.file_utils import FileChecker
|
|
5
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class CommonConfig:
|
|
8
9
|
def __init__(self, json_config):
|
|
9
10
|
self.task = json_config.get('task')
|
|
10
11
|
self.dump_path = json_config.get('dump_path')
|
|
11
|
-
self.rank = json_config.get('rank')
|
|
12
|
-
self.step = json_config.get('step')
|
|
12
|
+
self.rank = get_real_step_or_rank(json_config.get('rank'), Const.RANK)
|
|
13
|
+
self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
|
|
13
14
|
self.level = json_config.get('level')
|
|
14
|
-
self.seed = json_config.get('seed')
|
|
15
15
|
self.acl_config = json_config.get('acl_config')
|
|
16
|
-
self.is_deterministic = json_config.get('is_deterministic', False)
|
|
17
16
|
self.enable_dataloader = json_config.get('enable_dataloader', False)
|
|
18
17
|
self._check_config()
|
|
19
18
|
|
|
@@ -24,21 +23,9 @@ class CommonConfig:
|
|
|
24
23
|
if self.dump_path is not None and not isinstance(self.dump_path, str):
|
|
25
24
|
logger.error_log_with_exp("dump_path is invalid, it should be a string",
|
|
26
25
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
27
|
-
if self.rank is not None and not isinstance(self.rank, list):
|
|
28
|
-
logger.error_log_with_exp("rank is invalid, it should be a list",
|
|
29
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
30
|
-
if self.step is not None and not isinstance(self.step, list):
|
|
31
|
-
logger.error_log_with_exp("step is invalid, it should be a list",
|
|
32
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
33
26
|
if self.level and self.level not in Const.LEVEL_LIST:
|
|
34
27
|
logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST),
|
|
35
28
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
36
|
-
if self.seed is not None and not isinstance(self.seed, int):
|
|
37
|
-
logger.error_log_with_exp("seed is invalid, it should be an integer",
|
|
38
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
39
|
-
if not isinstance(self.is_deterministic, bool):
|
|
40
|
-
logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean",
|
|
41
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
42
29
|
if not isinstance(self.enable_dataloader, bool):
|
|
43
30
|
logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
|
|
44
31
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
@@ -73,13 +60,19 @@ class BaseConfig:
|
|
|
73
60
|
self.preheat_step = json_config.get("preheat_step")
|
|
74
61
|
self.max_sample = json_config.get("max_sample")
|
|
75
62
|
|
|
63
|
+
@staticmethod
|
|
64
|
+
def _check_str_list_config(config_item, config_name):
|
|
65
|
+
if config_item is not None:
|
|
66
|
+
if not isinstance(config_item, list):
|
|
67
|
+
logger.error_log_with_exp(f"{config_name} is invalid, it should be a list[str]",
|
|
68
|
+
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
69
|
+
for name in config_item:
|
|
70
|
+
if not isinstance(name, str):
|
|
71
|
+
logger.error_log_with_exp(f"{config_name} is invalid, it should be a list[str]",
|
|
72
|
+
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
73
|
+
|
|
76
74
|
def check_config(self):
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
logger.error_log_with_exp("list is invalid, it should be a list",
|
|
82
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
83
|
-
if self.data_mode is not None and not isinstance(self.data_mode, list):
|
|
84
|
-
logger.error_log_with_exp("data_mode is invalid, it should be a list",
|
|
85
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
75
|
+
self._check_str_list_config(self.scope, "scope")
|
|
76
|
+
self._check_str_list_config(self.list, "list")
|
|
77
|
+
self._check_str_list_config(self.data_mode, "data_mode")
|
|
78
|
+
self._check_str_list_config(self.backward_input, "backward_input")
|