mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.const import Const
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Runtime:
    """Holds shared runtime state as class-level attributes.

    Every attribute has an explicit default so it can be read on the class
    before anything assigns to it.
    """

    step_count: int = 0
    rank_id: int = -1
    is_running: bool = False
    run_mode: str = Const.PYNATIVE_MODE
    current_iter: int = 0
    # A bare annotation (``current_rank: None``) only records a type hint and
    # creates no attribute, so ``Runtime.current_rank`` would raise
    # AttributeError. Give it an explicit ``None`` default instead.
    current_rank: "int | None" = None
|
msprobe/core/common/utils.py
CHANGED
|
@@ -18,6 +18,7 @@ import os
|
|
|
18
18
|
import re
|
|
19
19
|
import subprocess
|
|
20
20
|
import time
|
|
21
|
+
import inspect
|
|
21
22
|
from datetime import datetime, timezone
|
|
22
23
|
|
|
23
24
|
import numpy as np
|
|
@@ -26,10 +27,15 @@ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_pa
|
|
|
26
27
|
from msprobe.core.common.const import Const, CompareConst
|
|
27
28
|
from msprobe.core.common.log import logger
|
|
28
29
|
from msprobe.core.common.exceptions import MsprobeException
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
# Lightweight record with ``type`` and ``index`` fields.
device = collections.namedtuple('device', 'type index')
prefixes = ['api_stack', 'list', 'range', 'acl']
# Maps the trailing file name of a json artifact to its file-type constant.
file_suffix_to_file_type = {
    "dump.json": Const.DUMP_JSON_FILE,
    "debug.json": Const.DEBUG_JSON_FILE,
}
|
|
33
39
|
|
|
34
40
|
|
|
35
41
|
class MsprobeBaseException(Exception):
|
|
@@ -74,6 +80,7 @@ class MsprobeBaseException(Exception):
|
|
|
74
80
|
NAMES_STRUCTS_MATCH_ERROR = 34
|
|
75
81
|
INVALID_STATE_ERROR = 35
|
|
76
82
|
INVALID_API_NAME_ERROR = 36
|
|
83
|
+
CROSS_FRAME_ERROR = 37
|
|
77
84
|
|
|
78
85
|
def __init__(self, code, error_info: str = ""):
|
|
79
86
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -190,27 +197,6 @@ def check_regex_prefix_format_valid(prefix):
|
|
|
190
197
|
raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
191
198
|
|
|
192
199
|
|
|
193
|
-
def execute_command(cmd):
|
|
194
|
-
"""
|
|
195
|
-
Function Description:
|
|
196
|
-
run the following command
|
|
197
|
-
Parameter:
|
|
198
|
-
cmd: command
|
|
199
|
-
Exception Description:
|
|
200
|
-
when invalid command throw exception
|
|
201
|
-
"""
|
|
202
|
-
logger.info('Execute command:%s' % cmd)
|
|
203
|
-
process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
|
204
|
-
while process.poll() is None:
|
|
205
|
-
line = process.stdout.readline()
|
|
206
|
-
line = line.strip()
|
|
207
|
-
if line:
|
|
208
|
-
logger.info(line)
|
|
209
|
-
if process.returncode != 0:
|
|
210
|
-
logger.error('Failed to execute command:%s' % " ".join(cmd))
|
|
211
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
212
|
-
|
|
213
|
-
|
|
214
200
|
def add_time_as_suffix(name):
    """Return ``name`` followed by ``_`` + a local ``%Y%m%d%H%M%S`` timestamp and ``.csv``."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{stamp}.csv"
|
|
216
202
|
|
|
@@ -231,17 +217,33 @@ def format_value(value):
|
|
|
231
217
|
return float('{:.12f}'.format(value))
|
|
232
218
|
|
|
233
219
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
220
|
+
@recursion_depth_decorator('msprobe.core.common.utils.md5_find', max_depth=Const.DUMP_MAX_DEPTH)
def md5_find(data, json_type=Const.DUMP_JSON_FILE):
    """Report whether ``data`` contains any ``Const.MD5`` entry.

    For ``Const.DUMP_JSON_FILE`` data, every field of every API entry is inspected:
    list fields are searched element-wise, bool fields are skipped, and any other
    truthy field is tested with ``Const.MD5 in field`` (for a str field that is a
    substring test).

    For ``Const.DEBUG_JSON_FILE`` data, nested dicts and lists are walked
    recursively (through the decorated ``md5_find``). The result is True as soon as
    any dict has a ``Const.MD5`` key.

    Any other ``json_type`` yields False.
    """
    if json_type == Const.DUMP_JSON_FILE:
        for op_name in data:
            op_entry = data[op_name]
            for field in op_entry:
                value = op_entry[field]
                if isinstance(value, list) and any(item and Const.MD5 in item for item in value):
                    return True
                if isinstance(value, bool):
                    continue
                if value and Const.MD5 in value:
                    return True
        return False

    if json_type == Const.DEBUG_JSON_FILE:
        if isinstance(data, dict):
            if Const.MD5 in data:
                return True
            children = data.values()
        elif isinstance(data, list):
            children = data
        else:
            return False
        return any(md5_find(child, Const.DEBUG_JSON_FILE) for child in children)

    return False
|
|
246
248
|
|
|
247
249
|
|
|
@@ -279,13 +281,41 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
|
279
281
|
def set_dump_path(input_param):
    """Validate the npu/bench json paths and record their tensor-data directories.

    Both ``npu_json_path`` and ``bench_json_path`` must be present. They must either
    both end with ``dump.json`` or both end with ``debug.json``. On success,
    ``input_param`` is updated in place with the ``Const.DUMP_TENSOR_DATA`` directory
    next to each json file.

    Raises:
        CompareException: INVALID_PATH_ERROR when the paths fail validation.
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)

    def _both_end_with(suffix):
        # Both paths must exist and refer to the same kind of json file.
        return (npu_path is not None and npu_path.endswith(suffix)
                and bench_path is not None and bench_path.endswith(suffix))

    if not (_both_end_with("dump.json") or _both_end_with("debug.json")):
        logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    npu_data_dir = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param[CompareConst.NPU_DUMP_DATA_DIR] = npu_data_dir
    bench_data_dir = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    input_param[CompareConst.BENCH_DUMP_DATA_DIR] = bench_data_dir
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def get_file_type(file_path):
    """Return the file-type constant for a dump.json / debug.json path.

    The path is matched by its ending against the keys of
    ``file_suffix_to_file_type``. This is consistent with ``set_dump_path``, which
    accepts any path ending with ``dump.json`` or ``debug.json`` (e.g.
    ``rank0_dump.json``). Matching only the exact last path component would reject
    such paths after they had already passed validation.

    Raises:
        CompareException: INVALID_PATH_ERROR when ``file_path`` is not a str or
            matches neither suffix.
    """
    if not isinstance(file_path, str):
        logger.error("get_file_type failed, check the type of file_path.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    for suffix, file_type in file_suffix_to_file_type.items():
        if file_path.endswith(suffix):
            return file_type
    logger.error("get_file_type failed, file_path is neither dump.json nor debug.json.")
    raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def check_dump_json_key(json_data, device_type):
    """Validate the top-level structure of a loaded dump/debug json.

    Args:
        json_data: parsed json content; expected to be a dict with a non-empty
            ``task`` and a dict-valued ``data`` entry.
        device_type: label used in log messages (e.g. ``'npu'`` or ``'bench'``).

    Returns:
        tuple: ``(task, api_data)``.

    Raises:
        CompareException: INVALID_DATA_ERROR when ``json_data`` is not a dict, or
            ``data`` is missing or not a dict; INVALID_TASK_ERROR when ``task`` is
            empty.
    """
    # Without this guard a top-level list (or other non-dict json) would crash
    # with AttributeError on ``.get`` instead of the CompareException callers expect.
    if not isinstance(json_data, dict):
        logger.error(f"Invalid json content for {device_type}: expected a dict, please check.")
        raise CompareException(CompareException.INVALID_DATA_ERROR)
    task = json_data.get('task', None)
    if not task:
        logger.error(f"Task for {device_type} is empty, please check.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)
    if 'data' not in json_data:
        logger.error(f"Missing 'data' in dump.json, please check dump.json of {device_type}.")
        raise CompareException(CompareException.INVALID_DATA_ERROR)
    api_data = json_data.get('data')
    if not isinstance(api_data, dict):
        logger.error(f"Invalid type for 'data': expected a dict. Please check dump.json of {device_type}.")
        raise CompareException(CompareException.INVALID_DATA_ERROR)
    return task, api_data
|
|
289
319
|
|
|
290
320
|
|
|
291
321
|
def get_dump_mode(input_param):
|
|
@@ -293,13 +323,10 @@ def get_dump_mode(input_param):
|
|
|
293
323
|
bench_path = input_param.get("bench_json_path", None)
|
|
294
324
|
npu_json_data = load_json(npu_path)
|
|
295
325
|
bench_json_data = load_json(bench_path)
|
|
326
|
+
json_type = get_file_type(file_path=npu_path)
|
|
296
327
|
|
|
297
|
-
npu_task = npu_json_data
|
|
298
|
-
bench_task = bench_json_data
|
|
299
|
-
|
|
300
|
-
if not npu_task or not bench_task:
|
|
301
|
-
logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
|
|
302
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
328
|
+
npu_task, npu_api_data = check_dump_json_key(npu_json_data, 'npu')
|
|
329
|
+
bench_task, bench_api_data = check_dump_json_key(bench_json_data, 'bench')
|
|
303
330
|
|
|
304
331
|
if npu_task != bench_task:
|
|
305
332
|
logger.error(f"Please check the dump task is consistent.")
|
|
@@ -312,8 +339,8 @@ def get_dump_mode(input_param):
|
|
|
312
339
|
return Const.STRUCTURE
|
|
313
340
|
|
|
314
341
|
if npu_task == Const.STATISTICS:
|
|
315
|
-
npu_md5_compare = md5_find(
|
|
316
|
-
bench_md5_compare = md5_find(
|
|
342
|
+
npu_md5_compare = md5_find(npu_api_data, json_type)
|
|
343
|
+
bench_md5_compare = md5_find(bench_api_data, json_type)
|
|
317
344
|
if npu_md5_compare == bench_md5_compare:
|
|
318
345
|
return Const.MD5 if npu_md5_compare else Const.SUMMARY
|
|
319
346
|
else:
|
|
@@ -436,6 +463,28 @@ def check_init_step(step):
|
|
|
436
463
|
f"{step} must be greater than or equal to 0")
|
|
437
464
|
|
|
438
465
|
|
|
466
|
+
def check_token_range(token_range):
    """Validate an optional ``[start, end]`` token range.

    ``None`` means no restriction. Otherwise ``token_range`` must be a list or
    tuple of exactly two integers with ``0 <= start <= end``.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR when any constraint is violated.
    """
    if token_range is None:
        return
    if not isinstance(token_range, (list, tuple)):
        logger.error("Token_range must be a list or tuple.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if len(token_range) != 2:
        logger.error("Token_range must contain exactly 2 elements.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    start, end = token_range
    # bool is a subclass of int; True/False are not meaningful token indices.
    if not all(isinstance(bound, int) and not isinstance(bound, bool) for bound in (start, end)):
        logger.error("Start and end in token_range must be integers.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start > end:
        logger.error("Start in token_range must be less than or equal to end.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start < 0:
        logger.error("Start in token_range must be >= 0.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
486
|
+
|
|
487
|
+
|
|
439
488
|
def check_seed_all(seed, mode, rm_dropout):
|
|
440
489
|
if is_int(seed):
|
|
441
490
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
@@ -505,4 +554,46 @@ def is_save_variable_valid(variable, valid_special_types, depth=0):
|
|
|
505
554
|
return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1)
|
|
506
555
|
for key, value in variable.items())
|
|
507
556
|
else:
|
|
508
|
-
return False
|
|
557
|
+
return False
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def replace_last_occurrence(text, old, new):
    """Return ``text`` with only its right-most occurrence of ``old`` replaced by ``new``.

    ``None`` and texts that do not contain ``old`` are returned unchanged.
    """
    if text is None:
        return text
    pos = text.rfind(old)
    if pos == -1:
        return text
    return text[:pos] + new + text[pos + len(old):]
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def load_stack_json(stack_path):
    """Load a stack json file and normalise it to an ``{api_name: stack_str}`` mapping.

    If ``Const.NEW_STACK_FLAG`` is absent or falsy, the loaded content is returned
    unchanged. Otherwise each value is expected to be an ``[api_names, stack_str]``
    pair. Values whose length is not 2 are skipped. Every name in ``api_names``
    maps to the shared ``stack_str``, and later entries win on duplicate names.
    """
    raw_stacks = load_json(stack_path)
    if not raw_stacks.get(Const.NEW_STACK_FLAG):
        return raw_stacks

    api_to_stack = {}
    for entry in raw_stacks.values():
        if len(entry) != 2:
            continue
        names, stack_text = entry
        api_to_stack.update(dict.fromkeys(names, stack_text))
    return api_to_stack
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def analyze_api_call_stack(name):
    """Render the Python call stack above this function's caller as text.

    The first two frames (this function and its direct caller) are dropped.
    Frames without source context are skipped. If the stack cannot be
    retrieved, or nothing is left after dropping, ``Const.WITHOUT_CALL_STACK``
    is returned. ``name`` is only used in the warning message.
    """
    # inspect.stack() must be called directly here: the [2:] offset depends on it.
    try:
        frames = inspect.stack()[2:]
    except Exception as err:
        logger.warning(f"The call stack of {name} failed to retrieve, {err}.")
        frames = None
    if not frames:
        return "".join([Const.WITHOUT_CALL_STACK])
    return "".join(
        f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()} \n"
        for (_, path, line, func, code, _) in frames
        if code
    )
|
msprobe/core/common_config.py
CHANGED
|
@@ -111,3 +111,10 @@ class BaseConfig:
|
|
|
111
111
|
f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
|
|
112
112
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
113
113
|
)
|
|
114
|
+
|
|
115
|
+
def _check_summary_mode(self):
    """Reject a non-empty ``summary_mode`` that is not listed in ``Const.SUMMARY_MODE``.

    An empty or ``None`` summary_mode is accepted without further checks.
    Otherwise the failure is reported through ``logger.error_log_with_exp`` with an
    INVALID_PARAM_ERROR MsprobeException.
    """
    if not self.summary_mode or self.summary_mode in Const.SUMMARY_MODE:
        return
    logger.error_log_with_exp(
        f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.",
        MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    )
|