mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import multiprocessing
|
|
17
|
+
from multiprocessing.shared_memory import SharedMemory
|
|
18
|
+
import random
|
|
19
|
+
import time
|
|
20
|
+
import atexit
|
|
21
|
+
import os
|
|
22
|
+
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def is_main_process():
    """Return True when the current process is named 'MainProcess'.

    Processes started through ``multiprocessing`` normally get a different
    default name, so they report False.
    """
    process_name = multiprocessing.current_process().name
    return process_name == 'MainProcess'
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GlobalLock:
    """Cross-process spin lock stored in a one-byte ``SharedMemory`` segment.

    The segment name comes from ``get_lock_name``. A main process uses its own
    PID, and a child uses its parent's PID, so a main process and its direct
    children resolve to the same segment. Byte 0 holds the state:
    0 = free, 1 = held.

    NOTE(review): the check-then-set in ``acquire`` is not atomic. The
    randomized sleeps reduce, but do not eliminate, the chance that two
    processes take the lock at the same time.
    """

    def __init__(self):
        self.name = self.get_lock_name()
        # Loop rather than recurse. If another process creates the segment
        # between our attach attempt and our create attempt, simply retry the
        # attach instead of growing the call stack.
        while True:
            try:
                self._shm = SharedMemory(create=False, name=self.name)
            except FileNotFoundError:
                pass
            else:
                # Wait a random 0-50 ms to avoid acquiring the lock simultaneously.
                time.sleep(random.randint(0, 500) / 10000)
                return
            try:
                self._shm = SharedMemory(create=True, name=self.name, size=1)
            except FileExistsError:
                continue
            self._shm.buf[0] = 0
            logger.debug(f'{self.name} is created.')
            return

    @classmethod
    def get_lock_name(cls):
        """Return the segment name, keyed on the main process's PID."""
        if is_main_process():
            return f'global_lock_{os.getpid()}'
        return f'global_lock_{os.getppid()}'

    @classmethod
    def is_lock_exist(cls):
        """Return True if a segment with this process's lock name already exists."""
        try:
            SharedMemory(create=False, name=cls.get_lock_name()).close()
            return True
        except FileNotFoundError:
            return False

    def cleanup(self):
        """Close this handle; the main process also unlinks the segment."""
        self._shm.close()
        if is_main_process():
            try:
                self._shm.unlink()
                logger.debug(f'{self.name} is unlinked.')
            except FileNotFoundError:
                logger.warning(f'{self.name} has already been unlinked.')

    def acquire(self, timeout=180):
        """
        acquire global lock, default timeout is 3 minutes.

        If the lock is still held when ``timeout`` expires, it is taken anyway
        and a warning is logged.

        :param float timeout: timeout(seconds), default value is 180.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or shrink the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self._shm.buf[0] == 0:
                self._shm.buf[0] = 1
                return
            time.sleep(random.randint(10, 500) / 10000)  # spin, waiting 1-50 ms
        # Timed out: take the lock anyway rather than block forever (presumably
        # to recover from a holder that never released), but make it visible.
        logger.warning(f'Timed out after {timeout}s waiting for {self.name}, acquiring it forcibly.')
        self._shm.buf[0] = 1

    def release(self):
        """Mark the lock as free."""
        self._shm.buf[0] = 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
global_lock = GlobalLock()
|
|
86
|
+
atexit.register(global_lock.cleanup)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.const import Const
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Runtime:
    """Runtime state kept as class attributes.

    There is no ``__init__``, so this is presumably used as a class-level
    namespace (``Runtime.step_count`` etc.). TODO confirm against callers.
    """
    step_count: int = 0
    rank_id: int = -1
    is_running: bool = False
    run_mode: str = Const.PYNATIVE_MODE
    current_iter: int = 0
    # Explicit default. A bare annotation (``current_rank: None``) does not
    # create the attribute, so reading ``Runtime.current_rank`` before any
    # assignment raised AttributeError. Type is not established here,
    # presumably an int rank once set.
    current_rank = None
|
msprobe/core/common/utils.py
CHANGED
|
@@ -18,9 +18,8 @@ import os
|
|
|
18
18
|
import re
|
|
19
19
|
import subprocess
|
|
20
20
|
import time
|
|
21
|
-
|
|
21
|
+
import inspect
|
|
22
22
|
from datetime import datetime, timezone
|
|
23
|
-
from functools import wraps
|
|
24
23
|
|
|
25
24
|
import numpy as np
|
|
26
25
|
|
|
@@ -28,10 +27,15 @@ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_pa
|
|
|
28
27
|
from msprobe.core.common.const import Const, CompareConst
|
|
29
28
|
from msprobe.core.common.log import logger
|
|
30
29
|
from msprobe.core.common.exceptions import MsprobeException
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
34
34
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
35
|
+
file_suffix_to_file_type = {
|
|
36
|
+
"dump.json": Const.DUMP_JSON_FILE,
|
|
37
|
+
"debug.json": Const.DEBUG_JSON_FILE,
|
|
38
|
+
}
|
|
35
39
|
|
|
36
40
|
|
|
37
41
|
class MsprobeBaseException(Exception):
|
|
@@ -75,6 +79,8 @@ class MsprobeBaseException(Exception):
|
|
|
75
79
|
MERGE_COMPARE_RESULT_ERROR = 33
|
|
76
80
|
NAMES_STRUCTS_MATCH_ERROR = 34
|
|
77
81
|
INVALID_STATE_ERROR = 35
|
|
82
|
+
INVALID_API_NAME_ERROR = 36
|
|
83
|
+
CROSS_FRAME_ERROR = 37
|
|
78
84
|
|
|
79
85
|
def __init__(self, code, error_info: str = ""):
|
|
80
86
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -191,27 +197,6 @@ def check_regex_prefix_format_valid(prefix):
|
|
|
191
197
|
raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
192
198
|
|
|
193
199
|
|
|
194
|
-
def execute_command(cmd):
|
|
195
|
-
"""
|
|
196
|
-
Function Description:
|
|
197
|
-
run the following command
|
|
198
|
-
Parameter:
|
|
199
|
-
cmd: command
|
|
200
|
-
Exception Description:
|
|
201
|
-
when invalid command throw exception
|
|
202
|
-
"""
|
|
203
|
-
logger.info('Execute command:%s' % cmd)
|
|
204
|
-
process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
|
205
|
-
while process.poll() is None:
|
|
206
|
-
line = process.stdout.readline()
|
|
207
|
-
line = line.strip()
|
|
208
|
-
if line:
|
|
209
|
-
logger.info(line)
|
|
210
|
-
if process.returncode != 0:
|
|
211
|
-
logger.error('Failed to execute command:%s' % " ".join(cmd))
|
|
212
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
213
|
-
|
|
214
|
-
|
|
215
200
|
def add_time_as_suffix(name):
|
|
216
201
|
return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
217
202
|
|
|
@@ -232,21 +217,41 @@ def format_value(value):
|
|
|
232
217
|
return float('{:.12f}'.format(value))
|
|
233
218
|
|
|
234
219
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
220
|
+
@recursion_depth_decorator('msprobe.core.common.utils.md5_find', max_depth=Const.DUMP_MAX_DEPTH)
def md5_find(data, json_type=Const.DUMP_JSON_FILE):
    """Report whether parsed dump data contains any ``Const.MD5`` key.

    :param data: the ``data`` section of a loaded dump.json / debug.json.
    :param json_type: ``Const.DUMP_JSON_FILE`` (default) or
        ``Const.DEBUG_JSON_FILE``; selects how ``data`` is walked. Any other
        value yields False.
    :return: True as soon as an MD5 entry is found, otherwise False.
    """
    if json_type == Const.DUMP_JSON_FILE:
        # Two-level mapping: {op_name: {field: value}}. A value may be a list
        # of entries, a bool flag (skipped), or a single entry.
        for key_op in data:
            op_fields = data[key_op]
            for api_info in op_fields:
                field_value = op_fields[api_info]
                if isinstance(field_value, list):
                    if any(detail and Const.MD5 in detail for detail in field_value):
                        return True
                if isinstance(field_value, bool):
                    continue
                if field_value and Const.MD5 in field_value:
                    return True
        return False

    if json_type != Const.DEBUG_JSON_FILE:
        return False

    # debug.json data may nest dicts and lists; walk them recursively.
    if isinstance(data, dict):
        if Const.MD5 in data:
            return True
        return any(md5_find(data_info, Const.DEBUG_JSON_FILE) for data_info in data.values())
    if isinstance(data, list):
        return any(md5_find(data_info, Const.DEBUG_JSON_FILE) for data_info in data)
    return False
|
|
247
248
|
|
|
248
249
|
|
|
249
250
|
def detect_framework_by_dump_json(file_path):
|
|
251
|
+
json_data = load_json(file_path)
|
|
252
|
+
framework = json_data.get("framework", None)
|
|
253
|
+
if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]:
|
|
254
|
+
return framework
|
|
250
255
|
pattern_ms = r'"type":\s*"mindspore'
|
|
251
256
|
pattern_pt = r'"type":\s*"torch'
|
|
252
257
|
with FileOpen(file_path, 'r') as file:
|
|
@@ -276,13 +281,26 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
|
276
281
|
def set_dump_path(input_param):
    """Validate the npu/bench json paths and record their tensor-data directories.

    Both ``npu_json_path`` and ``bench_json_path`` must be present. They must
    share the same suffix, either ``dump.json`` or ``debug.json``.

    On success, the ``Const.DUMP_TENSOR_DATA`` directory next to each json file
    is stored in ``input_param``:
    ``CompareConst.NPU_DUMP_DATA_DIR`` for the npu path and
    ``CompareConst.BENCH_DUMP_DATA_DIR`` for the bench path.

    :raises CompareException: INVALID_PATH_ERROR when a path is missing or the
        suffixes do not match.
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)

    def _both_end_with(suffix):
        return (npu_path is not None and npu_path.endswith(suffix)
                and bench_path is not None and bench_path.endswith(suffix))

    if not (_both_end_with("dump.json") or _both_end_with("debug.json")):
        logger.error("Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    npu_data_dir = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    bench_data_dir = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    input_param[CompareConst.NPU_DUMP_DATA_DIR] = npu_data_dir
    input_param[CompareConst.BENCH_DUMP_DATA_DIR] = bench_data_dir
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def get_file_type(file_path):
    """Map a dump/debug json path to its file-type constant.

    Classification uses the path suffix, with the same ``endswith`` rule
    ``set_dump_path`` applies. Any path accepted there is therefore classified
    here as well.

    :param file_path: path to a ``*dump.json`` or ``*debug.json`` file.
    :return: ``Const.DUMP_JSON_FILE`` or ``Const.DEBUG_JSON_FILE``.
    :raises CompareException: INVALID_PATH_ERROR if ``file_path`` is not a str
        or ends with neither suffix.
    """
    if not isinstance(file_path, str):
        logger.error("get_file_type failed, check the type of file_path.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    # Match by suffix rather than requiring a Const.SCOPE_SEPARATOR-split
    # component to equal the suffix exactly. The exact match rejected prefixed
    # names such as "npu_dump.json" that set_dump_path accepts.
    for suffix, file_type in file_suffix_to_file_type.items():
        if file_path.endswith(suffix):
            return file_type
    logger.error("get_file_type failed, file_path is neither dump.json nor debug.json.")
    raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
286
304
|
|
|
287
305
|
|
|
288
306
|
def get_dump_mode(input_param):
|
|
@@ -290,6 +308,7 @@ def get_dump_mode(input_param):
|
|
|
290
308
|
bench_path = input_param.get("bench_json_path", None)
|
|
291
309
|
npu_json_data = load_json(npu_path)
|
|
292
310
|
bench_json_data = load_json(bench_path)
|
|
311
|
+
json_type = get_file_type(file_path=npu_path)
|
|
293
312
|
|
|
294
313
|
npu_task = npu_json_data.get('task', None)
|
|
295
314
|
bench_task = bench_json_data.get('task', None)
|
|
@@ -309,8 +328,8 @@ def get_dump_mode(input_param):
|
|
|
309
328
|
return Const.STRUCTURE
|
|
310
329
|
|
|
311
330
|
if npu_task == Const.STATISTICS:
|
|
312
|
-
npu_md5_compare = md5_find(npu_json_data['data'])
|
|
313
|
-
bench_md5_compare = md5_find(bench_json_data['data'])
|
|
331
|
+
npu_md5_compare = md5_find(npu_json_data['data'], json_type)
|
|
332
|
+
bench_md5_compare = md5_find(bench_json_data['data'], json_type)
|
|
314
333
|
if npu_md5_compare == bench_md5_compare:
|
|
315
334
|
return Const.MD5 if npu_md5_compare else Const.SUMMARY
|
|
316
335
|
else:
|
|
@@ -424,6 +443,37 @@ def get_real_step_or_rank(step_or_rank_input, obj):
|
|
|
424
443
|
return real_step_or_rank
|
|
425
444
|
|
|
426
445
|
|
|
446
|
+
def check_init_step(step):
    """Validate that ``step`` passes ``is_int`` and is not negative.

    :raises MsprobeException: INVALID_PARAM_ERROR if either rule is violated.
    """
    if not is_int(step):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{step} must be an integer")
    if step < 0:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{step} must be greater than or equal to 0")
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def check_token_range(token_range):
    """Validate an optional ``[start, end]`` token range.

    ``None`` means "no range" and is accepted. Otherwise ``token_range`` must
    be a list or tuple holding exactly two non-bool integers with
    ``0 <= start <= end``.

    :raises MsprobeException: INVALID_PARAM_ERROR when any rule is violated.
    """
    if token_range is None:
        return
    if not isinstance(token_range, (list, tuple)):
        logger.error("Token_range must be a list or tuple.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if len(token_range) != 2:
        logger.error("Token_range must contains exactly 2 elements.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    start, end = token_range
    # bool is a subclass of int, so a plain isinstance(x, int) would accept
    # True/False as range bounds.
    if not all(isinstance(bound, int) and not isinstance(bound, bool) for bound in (start, end)):
        logger.error("Start and end in token_range must be integer.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start > end:
        logger.error("Start in token_range must less than the end.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if start < 0:
        logger.error("Start in token_range must >= 0.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
475
|
+
|
|
476
|
+
|
|
427
477
|
def check_seed_all(seed, mode, rm_dropout):
|
|
428
478
|
if is_int(seed):
|
|
429
479
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
@@ -467,36 +517,6 @@ def safe_get_value(container, index, container_name, key=None):
|
|
|
467
517
|
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
|
|
468
518
|
|
|
469
519
|
|
|
470
|
-
# 记录工具函数递归的深度
|
|
471
|
-
recursion_depth = defaultdict(int)
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
# 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。
|
|
475
|
-
def recursion_depth_decorator(func_info):
|
|
476
|
-
def decorator(func):
|
|
477
|
-
@wraps(func)
|
|
478
|
-
def wrapper(*args, **kwargs):
|
|
479
|
-
func_id = id(func)
|
|
480
|
-
recursion_depth[func_id] += 1
|
|
481
|
-
if recursion_depth[func_id] > Const.MAX_DEPTH:
|
|
482
|
-
msg = f"call {func_info} exceeds the recursion limit."
|
|
483
|
-
logger.error_log_with_exp(
|
|
484
|
-
msg,
|
|
485
|
-
MsprobeException(
|
|
486
|
-
MsprobeException.RECURSION_LIMIT_ERROR, msg
|
|
487
|
-
),
|
|
488
|
-
)
|
|
489
|
-
try:
|
|
490
|
-
result = func(*args, **kwargs)
|
|
491
|
-
finally:
|
|
492
|
-
recursion_depth[func_id] -= 1
|
|
493
|
-
return result
|
|
494
|
-
|
|
495
|
-
return wrapper
|
|
496
|
-
|
|
497
|
-
return decorator
|
|
498
|
-
|
|
499
|
-
|
|
500
520
|
def check_str_param(param):
|
|
501
521
|
if not re.match(Const.REGEX_PREFIX_PATTERN, param):
|
|
502
522
|
logger.error('The parameter {} contains special characters.'.format(param))
|
|
@@ -509,4 +529,60 @@ class DumpPathAggregation:
|
|
|
509
529
|
construct_file_path = None
|
|
510
530
|
dump_tensor_data_dir = None
|
|
511
531
|
free_benchmark_file_path = None
|
|
512
|
-
debug_file_path = None
|
|
532
|
+
debug_file_path = None
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def is_save_variable_valid(variable, valid_special_types, depth=0):
    """Recursively check that ``variable`` is built only from savable values.

    A value is valid in any of these cases:

    - it is an instance of ``valid_special_types``;
    - it is a list or tuple whose items are all valid;
    - it is a dict whose keys are all ``str`` and whose values are all valid.

    Nesting deeper than ``Const.DUMP_MAX_DEPTH`` is rejected.

    :param valid_special_types: a type or tuple of types accepted by ``isinstance``.
    :param depth: current nesting level; top-level callers leave it at 0.
    :return: bool.
    """
    if depth > Const.DUMP_MAX_DEPTH:
        return False
    if isinstance(variable, valid_special_types):
        return True

    child_depth = depth + 1
    if isinstance(variable, (list, tuple)):
        for item in variable:
            if not is_save_variable_valid(item, valid_special_types, child_depth):
                return False
        return True
    if isinstance(variable, dict):
        for key, value in variable.items():
            if not isinstance(key, str):
                return False
            if not is_save_variable_valid(value, valid_special_types, child_depth):
                return False
        return True
    return False
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def replace_last_occurrence(text, old, new):
    """Return ``text`` with only the last occurrence of ``old`` replaced by ``new``.

    ``None`` is returned as-is, and so is any text that does not contain ``old``.
    """
    if text is None:
        return text
    position = text.rfind(old)
    if position == -1:
        return text
    return text[:position] + new + text[position + len(old):]
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def load_stack_json(stack_path):
    """Load a stack json file and normalise it to ``{api_name: stack_str}``.

    Files without a truthy ``Const.NEW_STACK_FLAG`` entry are returned exactly
    as loaded. Otherwise each value is expected to be an
    ``[api_names, stack_str]`` pair, and every api name in the pair is mapped
    to that stack string.

    :param stack_path: path to the stack json file.
    :return: dict mapping api name to its call-stack string.
    """
    stack_dict = load_json(stack_path)
    if not stack_dict.get(Const.NEW_STACK_FLAG):
        return stack_dict

    new_stack_dict = {}
    for stack_info in stack_dict.values():
        # Only two-item sequences are stack entries. The NEW_STACK_FLAG entry's
        # own value is also iterated here, and it need not support len()
        # (e.g. a bool would raise TypeError), so skip anything else.
        if not isinstance(stack_info, (list, tuple)) or len(stack_info) != 2:
            continue
        api_list, stack_str = stack_info
        for api_name in api_list:
            new_stack_dict[api_name] = stack_str
    return new_stack_dict
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def analyze_api_call_stack(name):
    """Render the call stack above this function's caller as text.

    The two innermost frames (this function and its direct caller) are
    dropped. Each remaining frame that has source context contributes a
    ``"File <path>, line <n>, in <func>, \\n <code> \\n"`` entry.

    Return values:

    - the concatenated entries when frames remain (this can be an empty
      string if no remaining frame has source context);
    - ``Const.WITHOUT_CALL_STACK`` when the stack cannot be retrieved or no
      frames remain.

    :param name: label used in the warning logged when retrieval fails.
    """
    try:
        frames = inspect.stack()[2:]
    except Exception as e:
        logger.warning(f"The call stack of {name} failed to retrieve, {e}.")
        frames = None
    if not frames:
        return Const.WITHOUT_CALL_STACK
    return "".join(
        f"File {path}, line {line}, in {func}, \n {code[0].strip()} \n"
        for _, path, line, func, code, _ in frames
        if code
    )
|
msprobe/core/common_config.py
CHANGED
|
@@ -111,3 +111,10 @@ class BaseConfig:
|
|
|
111
111
|
f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
|
|
112
112
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
113
113
|
)
|
|
114
|
+
|
|
115
|
+
def _check_summary_mode(self):
    """Reject a configured summary_mode that is not listed in Const.SUMMARY_MODE.

    An empty or unset summary_mode is allowed. An unsupported value is
    reported through ``logger.error_log_with_exp``, together with an
    INVALID_PARAM_ERROR MsprobeException.
    """
    summary_mode = self.summary_mode
    if not summary_mode or summary_mode in Const.SUMMARY_MODE:
        return
    logger.error_log_with_exp(
        f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.",
        MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    )
|