mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/core/common/log.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import time
|
|
3
18
|
import sys
|
|
@@ -5,6 +20,16 @@ from functools import wraps
|
|
|
5
20
|
from msprobe.core.common.const import MsgConst
|
|
6
21
|
|
|
7
22
|
|
|
23
|
+
def filter_special_chars(func):
|
|
24
|
+
@wraps(func)
|
|
25
|
+
def func_level(self, msg, **kwargs):
|
|
26
|
+
for char in MsgConst.SPECIAL_CHAR:
|
|
27
|
+
msg = msg.replace(char, '_')
|
|
28
|
+
return func(self, msg, **kwargs)
|
|
29
|
+
|
|
30
|
+
return func_level
|
|
31
|
+
|
|
32
|
+
|
|
8
33
|
class BaseLogger:
|
|
9
34
|
def __init__(self):
|
|
10
35
|
self.rank = None
|
|
@@ -21,14 +46,6 @@ class BaseLogger:
|
|
|
21
46
|
def get_rank(self):
|
|
22
47
|
return self.rank
|
|
23
48
|
|
|
24
|
-
def filter_special_chars(func):
|
|
25
|
-
@wraps(func)
|
|
26
|
-
def func_level(self, msg, **kwargs):
|
|
27
|
-
for char in MsgConst.SPECIAL_CHAR:
|
|
28
|
-
msg = msg.replace(char, '_')
|
|
29
|
-
return func(self, msg, **kwargs)
|
|
30
|
-
return func_level
|
|
31
|
-
|
|
32
49
|
@filter_special_chars
|
|
33
50
|
def error(self, msg):
|
|
34
51
|
if self.level <= MsgConst.LogLevel.ERROR.value:
|
|
@@ -56,6 +73,7 @@ class BaseLogger:
|
|
|
56
73
|
return func(*args, **kwargs)
|
|
57
74
|
else:
|
|
58
75
|
return None
|
|
76
|
+
|
|
59
77
|
return func_rank_0
|
|
60
78
|
|
|
61
79
|
def info_on_rank_0(self, msg):
|
|
@@ -66,7 +84,7 @@ class BaseLogger:
|
|
|
66
84
|
|
|
67
85
|
def warning_on_rank_0(self, msg):
|
|
68
86
|
return self.on_rank_0(self.warning)(msg)
|
|
69
|
-
|
|
87
|
+
|
|
70
88
|
def error_log_with_exp(self, msg, exception):
|
|
71
89
|
self.error(msg)
|
|
72
90
|
raise exception
|
msprobe/core/common/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,21 +12,23 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import collections
|
|
18
17
|
import os
|
|
19
18
|
import re
|
|
20
19
|
import subprocess
|
|
21
20
|
import time
|
|
22
|
-
import
|
|
21
|
+
from collections import defaultdict
|
|
23
22
|
from datetime import datetime, timezone
|
|
23
|
+
from functools import wraps
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
24
26
|
|
|
25
27
|
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
26
28
|
from msprobe.core.common.const import Const, CompareConst
|
|
27
29
|
from msprobe.core.common.log import logger
|
|
28
30
|
from msprobe.core.common.exceptions import MsprobeException
|
|
29
31
|
|
|
30
|
-
|
|
31
32
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
32
33
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
33
34
|
|
|
@@ -68,6 +69,8 @@ class MsprobeBaseException(Exception):
|
|
|
68
69
|
FUNCTION_CALL_ERROR = 28
|
|
69
70
|
FORWARD_DATA_COLLECTION_ERROR = 29
|
|
70
71
|
BACKWARD_DATA_COLLECTION_ERROR = 30
|
|
72
|
+
INVALID_KEY_ERROR = 31
|
|
73
|
+
MISSING_HEADER_ERROR = 32
|
|
71
74
|
|
|
72
75
|
def __init__(self, code, error_info: str = ""):
|
|
73
76
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -99,7 +102,14 @@ class DumpException(MsprobeBaseException):
|
|
|
99
102
|
return f"Dump Error Code {self.code}: {self.error_info}"
|
|
100
103
|
|
|
101
104
|
|
|
102
|
-
def
|
|
105
|
+
def is_json_file(file_path):
|
|
106
|
+
if isinstance(file_path, str) and file_path.lower().endswith('.json'):
|
|
107
|
+
return True
|
|
108
|
+
else:
|
|
109
|
+
return False
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def check_compare_param(input_param, output_path, dump_mode):
|
|
103
113
|
if not isinstance(input_param, dict):
|
|
104
114
|
logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
|
|
105
115
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -107,10 +117,19 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
|
|
|
107
117
|
logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
|
|
108
118
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
109
119
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
120
|
+
def check_json_path(json_path_str):
|
|
121
|
+
json_path = input_param.get(json_path_str)
|
|
122
|
+
check_file_or_directory_path(json_path, False)
|
|
123
|
+
json_type_check = is_json_file(json_path)
|
|
124
|
+
if not json_type_check:
|
|
125
|
+
logger.error(f"Invalid {json_path_str}: {json_path}, please check!")
|
|
126
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
127
|
+
|
|
128
|
+
check_json_path("npu_json_path")
|
|
129
|
+
check_json_path("bench_json_path")
|
|
130
|
+
check_json_path("stack_json_path")
|
|
131
|
+
|
|
132
|
+
if dump_mode == Const.ALL:
|
|
114
133
|
check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
|
|
115
134
|
check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
|
|
116
135
|
check_file_or_directory_path(output_path, True)
|
|
@@ -179,7 +198,7 @@ def execute_command(cmd):
|
|
|
179
198
|
line = process.stdout.readline()
|
|
180
199
|
line = line.strip()
|
|
181
200
|
if line:
|
|
182
|
-
|
|
201
|
+
logger.info(line)
|
|
183
202
|
if process.returncode != 0:
|
|
184
203
|
logger.error('Failed to execute command:%s' % " ".join(cmd))
|
|
185
204
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
@@ -212,25 +231,29 @@ def md5_find(data):
|
|
|
212
231
|
for data_detail in data[key_op][api_info]:
|
|
213
232
|
if data_detail and 'md5' in data_detail:
|
|
214
233
|
return True
|
|
215
|
-
elif 'md5' in data[key_op][api_info]:
|
|
234
|
+
elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
|
|
216
235
|
return True
|
|
217
236
|
return False
|
|
218
237
|
|
|
219
238
|
|
|
220
|
-
def
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
239
|
+
def detect_framework_by_dump_json(file_path):
|
|
240
|
+
pattern_ms = r'"type":\s*"mindspore'
|
|
241
|
+
pattern_pt = r'"type":\s*"torch'
|
|
242
|
+
with FileOpen(file_path, 'r') as file:
|
|
243
|
+
for line in file:
|
|
244
|
+
if re.search(pattern_ms, line):
|
|
245
|
+
return Const.MS_FRAMEWORK
|
|
246
|
+
if re.search(pattern_pt, line):
|
|
247
|
+
return Const.PT_FRAMEWORK
|
|
248
|
+
logger.error(f"{file_path} must be based on the MindSpore or PyTorch framework.")
|
|
249
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
250
|
+
|
|
228
251
|
|
|
229
|
-
|
|
230
|
-
if not
|
|
231
|
-
logger.error(
|
|
252
|
+
def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
253
|
+
if not dump_json_path:
|
|
254
|
+
logger.error("The path is empty. Please enter a valid path.")
|
|
232
255
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
233
|
-
directory = os.path.dirname(
|
|
256
|
+
directory = os.path.dirname(dump_json_path)
|
|
234
257
|
check_file_or_directory_path(directory, True)
|
|
235
258
|
stack_json = os.path.join(directory, "stack.json")
|
|
236
259
|
construct_json = os.path.join(directory, "construct.json")
|
|
@@ -240,41 +263,57 @@ def struct_json_get(input_param, framework):
|
|
|
240
263
|
return stack, construct
|
|
241
264
|
|
|
242
265
|
|
|
243
|
-
def
|
|
266
|
+
def set_dump_path(input_param):
|
|
244
267
|
npu_path = input_param.get("npu_json_path", None)
|
|
245
268
|
bench_path = input_param.get("bench_json_path", None)
|
|
246
|
-
|
|
247
|
-
|
|
269
|
+
npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
|
|
270
|
+
bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
|
|
271
|
+
if not npu_path_valid or not bench_path_valid:
|
|
272
|
+
logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
|
|
248
273
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
274
|
+
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
275
|
+
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def get_dump_mode(input_param):
|
|
279
|
+
npu_path = input_param.get("npu_json_path", None)
|
|
280
|
+
bench_path = input_param.get("bench_json_path", None)
|
|
281
|
+
npu_json_data = load_json(npu_path)
|
|
282
|
+
bench_json_data = load_json(bench_path)
|
|
283
|
+
|
|
284
|
+
npu_task = npu_json_data.get('task', None)
|
|
285
|
+
bench_task = bench_json_data.get('task', None)
|
|
286
|
+
|
|
287
|
+
if not npu_task or not bench_task:
|
|
288
|
+
logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
|
|
289
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
290
|
+
|
|
291
|
+
if npu_task != bench_task:
|
|
254
292
|
logger.error(f"Please check the dump task is consistent.")
|
|
255
293
|
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
294
|
+
|
|
295
|
+
if npu_task == Const.TENSOR:
|
|
296
|
+
return Const.ALL
|
|
297
|
+
|
|
298
|
+
if npu_task == Const.STATISTICS:
|
|
299
|
+
npu_md5_compare = md5_find(npu_json_data['data'])
|
|
300
|
+
bench_md5_compare = md5_find(bench_json_data['data'])
|
|
301
|
+
if npu_md5_compare == bench_md5_compare:
|
|
302
|
+
return Const.MD5 if npu_md5_compare else Const.SUMMARY
|
|
263
303
|
else:
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
268
|
-
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
269
|
-
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
270
|
-
return summary_compare, md5_compare
|
|
304
|
+
logger.error(f"Please check the dump task is consistent, "
|
|
305
|
+
f"dump mode of npu and bench should both be statistics or md5.")
|
|
306
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
271
307
|
|
|
308
|
+
logger.error(f"Compare applies only to task is tensor or statistics")
|
|
309
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
272
310
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
311
|
+
|
|
312
|
+
def get_header_index(header_name, dump_mode):
|
|
313
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE.get(dump_mode)
|
|
314
|
+
if not header:
|
|
315
|
+
logger.error(f"{dump_mode} not in {CompareConst.HEAD_OF_COMPARE_MODE}")
|
|
316
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
278
317
|
if header_name not in header:
|
|
279
318
|
logger.error(f"{header_name} not in data name")
|
|
280
319
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -282,7 +321,7 @@ def get_header_index(header_name, summary_compare=False):
|
|
|
282
321
|
|
|
283
322
|
|
|
284
323
|
def convert_tuple(data):
|
|
285
|
-
return data if isinstance(data, tuple) else (data,
|
|
324
|
+
return data if isinstance(data, tuple) else (data,)
|
|
286
325
|
|
|
287
326
|
|
|
288
327
|
def check_op_str_pattern_valid(string, op_name=None, stack=False):
|
|
@@ -302,6 +341,10 @@ def is_invalid_pattern(string):
|
|
|
302
341
|
return re.search(pattern, string)
|
|
303
342
|
|
|
304
343
|
|
|
344
|
+
def is_int(x):
|
|
345
|
+
return isinstance(x, int) and not isinstance(x, bool)
|
|
346
|
+
|
|
347
|
+
|
|
305
348
|
def print_tools_ends_info():
|
|
306
349
|
total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
307
350
|
logger.info('*' * total_len)
|
|
@@ -315,40 +358,47 @@ def get_step_or_rank_from_string(step_or_rank, obj):
|
|
|
315
358
|
try:
|
|
316
359
|
borderlines = int(splited[0]), int(splited[1])
|
|
317
360
|
except (ValueError, IndexError) as e:
|
|
318
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
361
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
319
362
|
"The hyphen(-) must start and end with decimal numbers.") from e
|
|
320
363
|
else:
|
|
321
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
322
|
-
f'The string parameter for {obj} only supports formats like "3-5".
|
|
323
|
-
|
|
364
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
365
|
+
f'The string parameter for {obj} only supports formats like "3-5". '
|
|
366
|
+
f'Now string parameter for {obj} is "{step_or_rank}".')
|
|
367
|
+
if all(Const.STEP_RANK_MINIMUM_VALUE <= b <= Const.STEP_RANK_MAXIMUM_VALUE for b in borderlines):
|
|
324
368
|
if borderlines[0] <= borderlines[1]:
|
|
325
369
|
continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
|
|
326
370
|
else:
|
|
327
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
328
|
-
|
|
371
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
372
|
+
f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be '
|
|
373
|
+
f'greater than the right boundary ({borderlines[1]}).')
|
|
329
374
|
else:
|
|
330
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
331
|
-
f"The boundaries must fall within the range of
|
|
375
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
376
|
+
f"The boundaries must fall within the range of "
|
|
377
|
+
f"[{Const.STEP_RANK_MINIMUM_VALUE}, {Const.STEP_RANK_MAXIMUM_VALUE}].")
|
|
332
378
|
return continual_step_or_rank
|
|
333
379
|
|
|
334
380
|
|
|
335
381
|
def get_real_step_or_rank(step_or_rank_input, obj):
|
|
336
382
|
if obj not in [Const.STEP, Const.RANK]:
|
|
337
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
383
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
338
384
|
f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
|
|
339
385
|
if step_or_rank_input is None:
|
|
340
386
|
return []
|
|
341
387
|
if not isinstance(step_or_rank_input, list):
|
|
342
388
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
|
|
389
|
+
if len(step_or_rank_input) > Const.STEP_RANK_MAXIMUM_VALUE:
|
|
390
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
391
|
+
f"{obj} is invalid, its length cannot exceed {Const.STEP_RANK_MAXIMUM_VALUE}")
|
|
392
|
+
|
|
343
393
|
real_step_or_rank = []
|
|
344
394
|
for element in step_or_rank_input:
|
|
345
|
-
if not
|
|
346
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
395
|
+
if not is_int(element) and not isinstance(element, str):
|
|
396
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
347
397
|
f"{obj} element {element} must be an integer or string.")
|
|
348
398
|
if isinstance(element, int) and element < 0:
|
|
349
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
399
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
350
400
|
f"Each element of {obj} must be non-negative, currently it is {element}.")
|
|
351
|
-
if isinstance(element, int) and Const.
|
|
401
|
+
if isinstance(element, int) and Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
|
|
352
402
|
real_step_or_rank.append(element)
|
|
353
403
|
elif isinstance(element, str) and Const.HYPHEN in element:
|
|
354
404
|
continual_step_or_rank = get_step_or_rank_from_string(element, obj)
|
|
@@ -359,7 +409,7 @@ def get_real_step_or_rank(step_or_rank_input, obj):
|
|
|
359
409
|
|
|
360
410
|
|
|
361
411
|
def check_seed_all(seed, mode):
|
|
362
|
-
if
|
|
412
|
+
if is_int(seed):
|
|
363
413
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
364
414
|
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
365
415
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
@@ -369,3 +419,66 @@ def check_seed_all(seed, mode):
|
|
|
369
419
|
if not isinstance(mode, bool):
|
|
370
420
|
logger.error("seed_all mode must be bool.")
|
|
371
421
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def safe_get_value(container, index, container_name, key=None):
|
|
425
|
+
try:
|
|
426
|
+
# 处理字典情况
|
|
427
|
+
if isinstance(container, dict):
|
|
428
|
+
return container.get(key)[index]
|
|
429
|
+
# 处理列表、元组、numpy情况
|
|
430
|
+
elif isinstance(container, (list, tuple, np.ndarray)):
|
|
431
|
+
return container[index]
|
|
432
|
+
else:
|
|
433
|
+
err_msg = f"Unsupported container type for '{container_name}': {type(container)}"
|
|
434
|
+
logger.error(err_msg)
|
|
435
|
+
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR)
|
|
436
|
+
except IndexError as e:
|
|
437
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
438
|
+
f"{container_name} is {container}\n" \
|
|
439
|
+
f"index is {index}"
|
|
440
|
+
logger.error(err_msg)
|
|
441
|
+
raise MsprobeBaseException(MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
442
|
+
except TypeError as e:
|
|
443
|
+
err_msg = "wrong type, please check!\n" \
|
|
444
|
+
f"{container_name} is {container}\n" \
|
|
445
|
+
f"index is {index}\n" \
|
|
446
|
+
f"key is {key}"
|
|
447
|
+
logger.error(err_msg)
|
|
448
|
+
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
# Current recursion depth of each decorated utility function, keyed by id() of the function.
# NOTE(review): shared module-level state with no locking.
recursion_depth = defaultdict(int)


# Decorate a function so that exceeding the recursion limit raises an error naming the function.
def recursion_depth_decorator(func_info):
    """Build a decorator that limits how deeply the decorated function may recurse.

    Args:
        func_info: Description of the decorated function, embedded in the error message.

    Returns:
        A decorator. The wrapped function counts its nesting depth in the module-level
        ``recursion_depth`` dict. Once that depth exceeds ``Const.MAX_DEPTH``, it calls
        ``logger.error_log_with_exp`` with ``MsprobeException.RECURSION_LIMIT_ERROR``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            # The limit check lives inside try/finally so the increment above is always undone,
            # even if error_log_with_exp (or anything else in the check) raises. Otherwise each
            # overflow would leave the counter one level too high and shrink the limit for later calls.
            try:
                if recursion_depth[func_id] > Const.MAX_DEPTH:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def check_str_param(param):
    """Reject ``param`` unless ``Const.REGEX_PREFIX_PATTERN`` matches at its beginning.

    Args:
        param: The value to validate; passed directly to ``re.match``.

    Raises:
        MsprobeBaseException: ``INVALID_CHAR_ERROR`` when ``re.match`` finds no match.
    """
    matched = re.match(Const.REGEX_PREFIX_PATTERN, param)
    # re.match returns None on failure; a Match object is always truthy.
    if matched is None:
        logger.error(f"The parameter {param} contains special characters.")
        raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
|
msprobe/core/common_config.py
CHANGED
|
@@ -1,7 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
2
17
|
from msprobe.core.common.log import logger
|
|
3
18
|
from msprobe.core.common.exceptions import MsprobeException
|
|
4
|
-
from msprobe.core.common.file_utils import FileChecker
|
|
5
19
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
6
20
|
|
|
7
21
|
|
|
@@ -12,7 +26,6 @@ class CommonConfig:
|
|
|
12
26
|
self.rank = get_real_step_or_rank(json_config.get('rank'), Const.RANK)
|
|
13
27
|
self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
|
|
14
28
|
self.level = json_config.get('level')
|
|
15
|
-
self.acl_config = json_config.get('acl_config')
|
|
16
29
|
self.enable_dataloader = json_config.get('enable_dataloader', False)
|
|
17
30
|
self._check_config()
|
|
18
31
|
|
|
@@ -29,16 +42,6 @@ class CommonConfig:
|
|
|
29
42
|
if not isinstance(self.enable_dataloader, bool):
|
|
30
43
|
logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
|
|
31
44
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
32
|
-
if self.acl_config:
|
|
33
|
-
self._check_acl_config()
|
|
34
|
-
|
|
35
|
-
def _check_acl_config(self):
|
|
36
|
-
if not isinstance(self.acl_config, str):
|
|
37
|
-
logger.error_log_with_exp("acl_config is invalid, it should be a string",
|
|
38
|
-
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
39
|
-
file_checker = FileChecker(
|
|
40
|
-
file_path=self.acl_config, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
41
|
-
file_checker.common_check()
|
|
42
45
|
|
|
43
46
|
|
|
44
47
|
class BaseConfig:
|
|
@@ -46,7 +49,6 @@ class BaseConfig:
|
|
|
46
49
|
self.scope = json_config.get('scope')
|
|
47
50
|
self.list = json_config.get('list')
|
|
48
51
|
self.data_mode = json_config.get('data_mode')
|
|
49
|
-
self.backward_input = json_config.get("backward_input")
|
|
50
52
|
self.file_format = json_config.get("file_format")
|
|
51
53
|
self.summary_mode = json_config.get("summary_mode")
|
|
52
54
|
self.overflow_nums = json_config.get("overflow_nums")
|
|
@@ -74,5 +76,32 @@ class BaseConfig:
|
|
|
74
76
|
def check_config(self):
    """Run the string-list checks on ``scope`` and ``list``, then validate ``data_mode``."""
    for value, field_name in ((self.scope, "scope"), (self.list, "list")):
        self._check_str_list_config(value, field_name)
    self._check_data_mode()
|
|
80
|
+
|
|
81
|
+
def _check_data_mode(self):
    """Validate ``self.data_mode``; ``None`` is accepted without further checks.

    A non-None value must be a list of strings, each contained in
    ``Const.DUMP_DATA_MODE_LIST``. ``Const.ALL`` may only appear on its own, and the list
    may hold at most ``len(Const.DUMP_DATA_MODE_LIST) - 1`` entries. Each violation is
    reported through ``logger.error_log_with_exp`` with
    ``MsprobeException.INVALID_PARAM_ERROR``.
    """
    if self.data_mode is None:
        return

    # NOTE(review): the checks below assume error_log_with_exp raises; otherwise they would
    # continue on an invalid value. Confirm in msprobe.core.common.log.
    if not isinstance(self.data_mode, list):
        logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))

    if Const.ALL in self.data_mode and len(self.data_mode) != 1:
        logger.error_log_with_exp(
            "'all' cannot be combined with other options in data_mode.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    # One fewer than the full option list: the largest allowed combination excludes one entry.
    max_modes = len(Const.DUMP_DATA_MODE_LIST) - 1
    if len(self.data_mode) > max_modes:
        logger.error_log_with_exp(
            f"The number of elements in data_mode cannot exceed {max_modes}.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    for mode in self.data_mode:
        if not isinstance(mode, str):
            logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                      MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
        if mode not in Const.DUMP_DATA_MODE_LIST:
            logger.error_log_with_exp(
                f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
                MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            )
|