mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -27,10 +27,14 @@ import numpy as np
|
|
|
27
27
|
from tqdm import tqdm
|
|
28
28
|
|
|
29
29
|
# 本地应用/库特定导入
|
|
30
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
30
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
31
31
|
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
|
|
32
32
|
from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
|
|
33
33
|
from msprobe.mindspore.common.log import logger
|
|
34
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
35
|
+
|
|
36
|
+
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
37
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
@@ -50,6 +54,12 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
|
50
54
|
# 初始化一个属性来存储当前的设备ID(用于日志中显示)
|
|
51
55
|
self.current_device_id = None
|
|
52
56
|
|
|
57
|
+
self.save_error_data = args.save_error_data
|
|
58
|
+
if self.save_error_data:
|
|
59
|
+
config, dump_path_aggregation = self.init_save_error_data(args)
|
|
60
|
+
self.data_collector = build_data_collector(config)
|
|
61
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
62
|
+
|
|
53
63
|
def process_on_device(self, device_id, api_infos, progress_queue):
|
|
54
64
|
"""
|
|
55
65
|
在特定设备上处理一部分API。
|
|
@@ -19,7 +19,8 @@ import sys
|
|
|
19
19
|
from pathlib import Path
|
|
20
20
|
import mindspore
|
|
21
21
|
from msprobe.mindspore.common.log import logger
|
|
22
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
22
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
23
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
23
24
|
import torch as mindtorch
|
|
24
25
|
from torch import Tensor as mindtorch_tensor
|
|
25
26
|
import torch.nn.functional as mindtorch_func
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,21 +13,50 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from collections import OrderedDict
|
|
17
|
+
|
|
18
|
+
from mindspore import Tensor
|
|
19
|
+
from mindspore.common.hook_handle import HookHandle
|
|
20
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
21
|
+
|
|
17
22
|
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
24
|
+
from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
|
|
25
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
|
+
from msprobe.mindspore.common.log import logger
|
|
27
|
+
from msprobe.mindspore.common.utils import (
|
|
28
|
+
is_mindtorch,
|
|
29
|
+
get_cells_and_names_with_index,
|
|
30
|
+
has_kwargs_in_forward_hook,
|
|
31
|
+
is_graph_mode_cell_dump_allowed
|
|
32
|
+
)
|
|
33
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
34
|
+
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
35
|
+
from msprobe.core.common.runtime import Runtime
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_cell_construct(construct):
|
|
39
|
+
def _construct(self, *args, **kwargs):
|
|
40
|
+
if hasattr(self, 'msprobe_hook'):
|
|
41
|
+
setattr(self, 'msprobe_input_kwargs', kwargs)
|
|
42
|
+
return construct(self, *args, **kwargs)
|
|
43
|
+
return _construct
|
|
18
44
|
|
|
19
45
|
|
|
20
46
|
class CellProcessor:
|
|
21
47
|
cell_count = {}
|
|
22
48
|
cell_stack = []
|
|
23
|
-
api_parent_node =
|
|
49
|
+
api_parent_node = None
|
|
24
50
|
module_node = {}
|
|
51
|
+
cell_bw_hook_kernels = {}
|
|
52
|
+
cell_backward_pre_hook = []
|
|
53
|
+
cell_backward_hook = []
|
|
25
54
|
|
|
26
55
|
def __init__(self, scope):
|
|
27
56
|
self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
|
|
28
57
|
|
|
29
58
|
@staticmethod
|
|
30
|
-
def
|
|
59
|
+
def set_and_get_calls_number(cell_name):
|
|
31
60
|
if cell_name not in CellProcessor.cell_count:
|
|
32
61
|
CellProcessor.cell_count[cell_name] = 0
|
|
33
62
|
else:
|
|
@@ -38,42 +67,184 @@ class CellProcessor:
|
|
|
38
67
|
def reset_cell_stats(cls):
|
|
39
68
|
cls.cell_count = {}
|
|
40
69
|
cls.cell_stack = []
|
|
41
|
-
cls.api_parent_node =
|
|
70
|
+
cls.api_parent_node = None
|
|
42
71
|
cls.module_node = {}
|
|
72
|
+
cls.cell_bw_hook_kernels = {}
|
|
73
|
+
cls.cell_backward_pre_hook = []
|
|
74
|
+
cls.cell_backward_hook = []
|
|
43
75
|
|
|
44
|
-
def
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
76
|
+
def register_cell_hook(self, models, build_hook, config: DebuggerConfig):
|
|
77
|
+
if not models:
|
|
78
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
79
|
+
'The model cannot be None, when level is "L0" or "mix"')
|
|
80
|
+
|
|
81
|
+
is_registered = False
|
|
82
|
+
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
83
|
+
cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
|
|
84
|
+
construct_name = '_call_impl' if is_mindtorch() else '_run_construct'
|
|
85
|
+
|
|
86
|
+
for index, cells_and_names in cells_with_index_in_pynative_mode.items():
|
|
87
|
+
model = models if index == "-1" else models[int(index)]
|
|
88
|
+
for name, cell in cells_and_names:
|
|
89
|
+
if cell == model:
|
|
90
|
+
continue
|
|
91
|
+
|
|
92
|
+
if not has_kwargs_in_forward_hook():
|
|
93
|
+
if not hasattr(cell.__class__, 'msprobe_construct'):
|
|
94
|
+
setattr(cell.__class__, 'msprobe_construct', True)
|
|
95
|
+
if hasattr(cell.__class__, construct_name):
|
|
96
|
+
setattr(cell.__class__, construct_name,
|
|
97
|
+
get_cell_construct(getattr(cell.__class__, construct_name)))
|
|
98
|
+
setattr(cell, 'msprobe_hook', True)
|
|
99
|
+
|
|
100
|
+
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
101
|
+
prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}'
|
|
102
|
+
|
|
103
|
+
forward_pre_hook = self.build_cell_hook(prefix, build_hook)
|
|
104
|
+
cell.register_forward_pre_hook(forward_pre_hook)
|
|
105
|
+
|
|
106
|
+
if not is_registered:
|
|
107
|
+
logger.info("The cell hook function is successfully mounted to the model.")
|
|
108
|
+
is_registered = True
|
|
109
|
+
|
|
110
|
+
if is_graph_mode_cell_dump_allowed(config):
|
|
111
|
+
cells_and_names_in_graph_mode = []
|
|
112
|
+
for index, cells_and_names in cells_with_index_in_graph_mode.items():
|
|
113
|
+
model = models if index == "-1" else models[int(index)]
|
|
114
|
+
for name, cell in cells_and_names:
|
|
115
|
+
if cell == model:
|
|
116
|
+
continue
|
|
117
|
+
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
118
|
+
cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
|
|
119
|
+
|
|
120
|
+
if cells_and_names_in_graph_mode:
|
|
121
|
+
Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
|
|
122
|
+
GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
|
|
51
123
|
|
|
52
|
-
|
|
53
|
-
|
|
124
|
+
def build_cell_hook(self, cell_name, build_data_hook):
|
|
125
|
+
def forward_pre_hook(cell, args):
|
|
126
|
+
index = CellProcessor.set_and_get_calls_number(cell_name)
|
|
127
|
+
full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
|
|
128
|
+
full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
|
|
54
129
|
|
|
55
|
-
|
|
56
|
-
self.scope.begin_module(full_name)
|
|
130
|
+
self.set_construct_info_in_pre_hook(full_forward_name)
|
|
57
131
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
132
|
+
if not hasattr(cell, 'msprobe_forward_hook'):
|
|
133
|
+
if is_mindtorch():
|
|
134
|
+
cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True)
|
|
135
|
+
else:
|
|
136
|
+
forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict())
|
|
137
|
+
if has_kwargs_in_forward_hook():
|
|
138
|
+
forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict())
|
|
139
|
+
handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict)
|
|
140
|
+
forward_hook_with_kwargs_dict[handle.handle_id] = True
|
|
141
|
+
else:
|
|
142
|
+
handle = HookHandle(forward_hook_dict)
|
|
143
|
+
forward_hook_dict[handle.handle_id] = forward_hook
|
|
144
|
+
forward_hook_dict.move_to_end(handle.handle_id, last=False)
|
|
145
|
+
|
|
146
|
+
setattr(cell, 'msprobe_forward_hook', True)
|
|
147
|
+
|
|
148
|
+
def get_backward_hook(backward_data_hook, full_backward_name):
|
|
149
|
+
def backward_hook_fn(cell, grad_input, grad_output):
|
|
150
|
+
new_output = backward_data_hook(cell, grad_input, grad_output)
|
|
151
|
+
self.set_construct_info_in_hook(full_backward_name)
|
|
152
|
+
cell.has_pre_hook_called = False
|
|
153
|
+
return new_output
|
|
154
|
+
return backward_hook_fn
|
|
155
|
+
|
|
156
|
+
enable_hooked = sum(
|
|
157
|
+
[isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args]
|
|
158
|
+
)
|
|
159
|
+
if enable_hooked:
|
|
160
|
+
backward_hook = OrderedDict()
|
|
161
|
+
hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
|
|
162
|
+
backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name)
|
|
163
|
+
CellProcessor.cell_backward_hook.append(backward_hook)
|
|
164
|
+
bw_hook = inner.CellBackwardHook(full_backward_name, cell,
|
|
165
|
+
self.cell_backward_hook[-1])
|
|
166
|
+
bw_hook.register_backward_hook()
|
|
167
|
+
CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
|
|
168
|
+
|
|
169
|
+
args = bw_hook(*args)
|
|
170
|
+
|
|
171
|
+
return args
|
|
172
|
+
|
|
173
|
+
def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
|
|
174
|
+
index = CellProcessor.cell_count.get(cell_name, 0)
|
|
175
|
+
full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
|
|
176
|
+
full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
|
|
177
|
+
|
|
178
|
+
self.set_construct_info_in_hook(full_forward_name)
|
|
179
|
+
|
|
180
|
+
hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
|
|
181
|
+
hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
|
|
182
|
+
if hook_result is not None:
|
|
183
|
+
outputs = hook_result
|
|
63
184
|
else:
|
|
64
|
-
|
|
185
|
+
outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output
|
|
186
|
+
|
|
187
|
+
bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name)
|
|
188
|
+
if bw_hook:
|
|
189
|
+
if not isinstance(outputs, (Tensor, tuple)):
|
|
190
|
+
logger.warning("For backward hooks to be called,"
|
|
191
|
+
" cell output should be a Tensor or a tuple of Tensors"
|
|
192
|
+
f" but received {type(outputs)}")
|
|
193
|
+
if isinstance(outputs, tuple):
|
|
194
|
+
new_outputs = bw_hook(*outputs)
|
|
195
|
+
else:
|
|
196
|
+
new_outputs = bw_hook(outputs)
|
|
197
|
+
if isinstance(outputs, tuple) and len(outputs) == 1:
|
|
198
|
+
new_outputs = (new_outputs,)
|
|
199
|
+
outputs = new_outputs
|
|
200
|
+
|
|
201
|
+
def get_backward_pre_hook(full_backward_name, backward_data_hook):
|
|
202
|
+
def backward_pre_hook_fn(cell, grad_output):
|
|
203
|
+
cell.has_pre_hook_called = True
|
|
204
|
+
self.set_construct_info_in_pre_hook(full_backward_name)
|
|
205
|
+
if backward_data_hook:
|
|
206
|
+
backward_data_hook(cell, (), grad_output)
|
|
207
|
+
self.set_construct_info_in_hook(full_backward_name)
|
|
208
|
+
cell.has_pre_hook_called = False
|
|
209
|
+
return backward_pre_hook_fn
|
|
65
210
|
|
|
66
|
-
|
|
67
|
-
|
|
211
|
+
backward_pre_hook = OrderedDict()
|
|
212
|
+
backward_data_hook = None if bw_hook else hook_set.backward_hook
|
|
213
|
+
backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook)
|
|
214
|
+
CellProcessor.cell_backward_pre_hook.append(backward_pre_hook)
|
|
215
|
+
bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell,
|
|
216
|
+
self.cell_backward_pre_hook[-1])
|
|
217
|
+
bw_pre_hook.register_backward_pre_hook()
|
|
68
218
|
|
|
69
|
-
|
|
219
|
+
if isinstance(outputs, tuple):
|
|
220
|
+
result = bw_pre_hook(*outputs)
|
|
221
|
+
else:
|
|
222
|
+
result = bw_pre_hook(outputs)
|
|
223
|
+
if isinstance(outputs, tuple):
|
|
224
|
+
if len(outputs) == 1:
|
|
225
|
+
result = (result,)
|
|
226
|
+
if len(result) != len(outputs):
|
|
227
|
+
raise TypeError(
|
|
228
|
+
f"The backward pre hook return value size is {len(result)} "
|
|
229
|
+
f"not equal to output size {len(outputs)}"
|
|
230
|
+
)
|
|
231
|
+
return result
|
|
232
|
+
|
|
233
|
+
return forward_pre_hook
|
|
70
234
|
|
|
71
|
-
def
|
|
72
|
-
if
|
|
73
|
-
|
|
235
|
+
def set_construct_info_in_pre_hook(self, full_name):
|
|
236
|
+
if self.cell_stack:
|
|
237
|
+
CellProcessor.module_node[full_name] = self.cell_stack[-1]
|
|
74
238
|
else:
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
239
|
+
CellProcessor.module_node[full_name] = None
|
|
240
|
+
CellProcessor.cell_stack.append(full_name)
|
|
241
|
+
CellProcessor.api_parent_node = full_name
|
|
242
|
+
if self.scope:
|
|
243
|
+
self.scope.begin_module(full_name)
|
|
244
|
+
|
|
245
|
+
def set_construct_info_in_hook(self, full_name):
|
|
246
|
+
if self.cell_stack:
|
|
247
|
+
CellProcessor.cell_stack.pop()
|
|
248
|
+
CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None
|
|
249
|
+
if self.scope:
|
|
250
|
+
self.scope.end_module(full_name)
|
|
@@ -34,19 +34,6 @@ class Parser:
|
|
|
34
34
|
if isinstance(subgraph_node.attrs, list):
|
|
35
35
|
subgraph_node.attrs.extend(attrs)
|
|
36
36
|
|
|
37
|
-
@staticmethod
|
|
38
|
-
def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
|
|
39
|
-
attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
|
|
40
|
-
match = attr_pattern.search(text, graph_node.pos)
|
|
41
|
-
if match:
|
|
42
|
-
attrs = match.group(1).strip().split('\n')
|
|
43
|
-
for attr in attrs:
|
|
44
|
-
if not attr:
|
|
45
|
-
break
|
|
46
|
-
key, value = attr.split(':')
|
|
47
|
-
if isinstance(graph_node.attrs, dict):
|
|
48
|
-
graph_node.attrs[key.strip()] = value.strip()
|
|
49
|
-
|
|
50
37
|
@staticmethod
|
|
51
38
|
def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
|
|
52
39
|
code_info = []
|
|
@@ -124,8 +111,9 @@ class Parser:
|
|
|
124
111
|
scope_match = scope_pattern.search(text, end_pos)
|
|
125
112
|
scope = scope_match.group(1) if scope_match else ""
|
|
126
113
|
|
|
127
|
-
id_pattern = re.compile(
|
|
128
|
-
|
|
114
|
+
id_pattern = re.compile(
|
|
115
|
+
r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"',
|
|
116
|
+
re.IGNORECASE)
|
|
129
117
|
unique_id_match = id_pattern.search(text, end_pos, scope_match.start())
|
|
130
118
|
unique_id = unique_id_match.group(1) if unique_id_match else None
|
|
131
119
|
|
|
@@ -186,7 +174,7 @@ class Parser:
|
|
|
186
174
|
node_info.var_inputs.append(callee_name)
|
|
187
175
|
|
|
188
176
|
def parse_subgraphs(self, text: str) -> None:
|
|
189
|
-
subgraph_pattern = re.compile(r'subgraph\s+@(\
|
|
177
|
+
subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+')
|
|
190
178
|
matches = list(subgraph_pattern.finditer(text))
|
|
191
179
|
end_pos = 0
|
|
192
180
|
for match in matches:
|
|
@@ -203,11 +191,6 @@ class Parser:
|
|
|
203
191
|
subgraph_info.end = end_pos
|
|
204
192
|
logging.info('Parsed subgraph: %s', subgraph_name)
|
|
205
193
|
|
|
206
|
-
def count_nodes(self) -> Tuple[int, int]:
|
|
207
|
-
total_nodes = len(self.nodes)
|
|
208
|
-
total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
|
|
209
|
-
return total_nodes, total_cnodes
|
|
210
|
-
|
|
211
194
|
def create_backward_map(self):
|
|
212
195
|
for node in self.nodes.values():
|
|
213
196
|
if node.scope and node.scope.startswith("Gradients"):
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
import mindspore as ms
|
|
18
|
+
from mindspore import dtype as mstype
|
|
18
19
|
|
|
19
20
|
from msprobe.core.common.const import Const as CoreConst
|
|
20
21
|
|
|
@@ -23,14 +24,20 @@ class Const:
|
|
|
23
24
|
CELL = "cell"
|
|
24
25
|
API = "api"
|
|
25
26
|
KERNEL = "kernel"
|
|
27
|
+
CELL_AND_API = 'cell_and_api'
|
|
26
28
|
TOOL_LEVEL_DICT = {
|
|
27
29
|
CoreConst.LEVEL_L0: CELL,
|
|
28
30
|
CoreConst.LEVEL_L1: API,
|
|
29
|
-
CoreConst.LEVEL_L2: KERNEL
|
|
31
|
+
CoreConst.LEVEL_L2: KERNEL,
|
|
32
|
+
CoreConst.LEVEL_MIX: CELL_AND_API
|
|
30
33
|
}
|
|
31
|
-
|
|
34
|
+
|
|
35
|
+
PYNATIVE_MODE = CoreConst.PYNATIVE_MODE
|
|
36
|
+
GRAPH_MODE = "graph"
|
|
32
37
|
GRAPH_GE_MODE = "graph_ge"
|
|
33
38
|
GRAPH_KBYK_MODE = "graph_kbyk"
|
|
39
|
+
PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE
|
|
40
|
+
|
|
34
41
|
JIT_LEVEL = "jit_level"
|
|
35
42
|
JIT_LEVEL_O0 = "O0"
|
|
36
43
|
JIT_LEVEL_O1 = "O1"
|
|
@@ -61,6 +68,7 @@ class Const:
|
|
|
61
68
|
DROPOUT_API_NAME_PREFIX = "dropout"
|
|
62
69
|
|
|
63
70
|
GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
|
|
71
|
+
GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD]
|
|
64
72
|
|
|
65
73
|
HOOK_MS_PREFIX_DICT = {
|
|
66
74
|
OPS_DATA_PREFIX: OPS_PREFIX,
|
|
@@ -69,6 +77,69 @@ class Const:
|
|
|
69
77
|
MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
|
|
70
78
|
}
|
|
71
79
|
|
|
80
|
+
NonDifferentiableType = (
|
|
81
|
+
mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte,
|
|
82
|
+
mstype.int16, mstype.short, mstype.uint16, mstype.ushort,
|
|
83
|
+
mstype.int32, mstype.intc, mstype.uint32, mstype.uintc,
|
|
84
|
+
mstype.int64, mstype.intp, mstype.uint64, mstype.uintp
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class MsCompareConst:
|
|
89
|
+
# api_info field
|
|
90
|
+
MINT = "Mint"
|
|
91
|
+
MINT_FUNCTIONAL = "MintFunctional"
|
|
92
|
+
TENSOR_API = "Tensor"
|
|
93
|
+
FUNCTIONAL_API = "Functional"
|
|
94
|
+
FUSION_API = "FUSION"
|
|
95
|
+
|
|
96
|
+
API_NAME_STR_LENGTH = 4
|
|
97
|
+
MAX_RECURSION_DEPTH = 20
|
|
98
|
+
|
|
99
|
+
# Mindtorch api_info field
|
|
100
|
+
MINDTORCH_TENSOR = "Tensor"
|
|
101
|
+
MINDTORCH = "Torch"
|
|
102
|
+
MINDTORCH_FUNC = "Functional"
|
|
103
|
+
MINDTORCH_NPU = "NPU"
|
|
104
|
+
MINDTORCH_DIST = "Distributed"
|
|
105
|
+
|
|
106
|
+
MT_VALID_API_TYPES = [
|
|
107
|
+
MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
|
|
108
|
+
]
|
|
109
|
+
SUPPORTED_FUSION_LIST = ["flash_attention_score"]
|
|
110
|
+
|
|
111
|
+
TASK_FIELD = "task"
|
|
112
|
+
STATISTICS_TASK = "statistics"
|
|
113
|
+
FRAMEWORK = "framework"
|
|
114
|
+
TENSOR_TASK = "tensor"
|
|
115
|
+
DUMP_DATA_DIR_FIELD = "dump_data_dir"
|
|
116
|
+
DATA_FIELD = "data"
|
|
117
|
+
|
|
118
|
+
# supported api yaml
|
|
119
|
+
SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
|
|
120
|
+
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
121
|
+
|
|
122
|
+
# detail_csv
|
|
123
|
+
DETAIL_CSV_API_NAME = "API Name"
|
|
124
|
+
DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
|
|
125
|
+
DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
|
|
126
|
+
DETAIL_CSV_SHAPE = "Shape"
|
|
127
|
+
DETAIL_CSV_PASS_STATUS = "Status"
|
|
128
|
+
DETAIL_CSV_MESSAGE = "Message"
|
|
129
|
+
DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
|
|
130
|
+
|
|
131
|
+
# result_csv
|
|
132
|
+
RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
|
|
133
|
+
RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
|
|
134
|
+
RESULT_CSV_FILE_NAME = "accuracy_checking_result"
|
|
135
|
+
|
|
136
|
+
EPSILON = 1e-8
|
|
137
|
+
|
|
138
|
+
class ProcessStatus:
|
|
139
|
+
SUCCESS = "success"
|
|
140
|
+
API_NOT_FOUND = "api_not_found"
|
|
141
|
+
EXCEPTION_SKIP = "exception_skip"
|
|
142
|
+
|
|
72
143
|
|
|
73
144
|
class FreeBenchmarkConst:
|
|
74
145
|
ADD_NOISE = "add_noise"
|