mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,15 +14,20 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
from collections import defaultdict
|
|
17
|
+
from collections import defaultdict, namedtuple
|
|
18
18
|
|
|
19
19
|
import mindspore as ms
|
|
20
20
|
from mindspore._c_expression import MSContext
|
|
21
21
|
|
|
22
|
-
from msprobe.core.common.const import Const, MsgConst
|
|
22
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
24
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
25
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
26
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
24
27
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
28
|
+
from msprobe.mindspore.common.utils import set_register_backward_hook_functions
|
|
25
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
26
31
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
27
32
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
28
33
|
from msprobe.mindspore.ms_config import parse_json_config
|
|
@@ -30,12 +35,21 @@ from msprobe.mindspore.runtime import Runtime
|
|
|
30
35
|
from msprobe.mindspore.service import Service
|
|
31
36
|
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
32
37
|
|
|
38
|
+
try:
|
|
39
|
+
from msprobe.lib import _msprobe_c
|
|
40
|
+
except ImportError:
|
|
41
|
+
_msprobe_c = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"])
|
|
45
|
+
|
|
33
46
|
|
|
34
47
|
class PrecisionDebugger:
|
|
35
48
|
_instance = None
|
|
36
49
|
task_not_need_service = [Const.GRAD_PROBE]
|
|
37
50
|
|
|
38
|
-
def __new__(cls, config_path=None,
|
|
51
|
+
def __new__(cls, config_path=None, task=None, dump_path=None,
|
|
52
|
+
level=None, step=None, opt=None):
|
|
39
53
|
if not cls._instance:
|
|
40
54
|
cls._instance = super().__new__(cls)
|
|
41
55
|
cls._instance.initialized = False
|
|
@@ -44,22 +58,65 @@ class PrecisionDebugger:
|
|
|
44
58
|
cls.first_start = False
|
|
45
59
|
return cls._instance
|
|
46
60
|
|
|
47
|
-
def __init__(self, config_path=None
|
|
61
|
+
def __init__(self, config_path=None, task=None, dump_path=None,
|
|
62
|
+
level=None, step=None):
|
|
48
63
|
if self.initialized:
|
|
49
64
|
return
|
|
50
65
|
self.initialized = True
|
|
66
|
+
|
|
67
|
+
set_register_backward_hook_functions()
|
|
68
|
+
|
|
51
69
|
if not config_path:
|
|
52
70
|
config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
|
|
71
|
+
|
|
72
|
+
config_params = ConfigParameters(config_path, task, dump_path, level)
|
|
73
|
+
self.check_input_params(config_params)
|
|
74
|
+
|
|
53
75
|
common_config, task_config = parse_json_config(config_path)
|
|
76
|
+
common_config.task = task if task else common_config.task
|
|
54
77
|
self.task = common_config.task
|
|
55
78
|
if self.task == Const.GRAD_PROBE:
|
|
56
79
|
self.gm = GradientMonitor(common_config, task_config)
|
|
57
80
|
return
|
|
81
|
+
common_config.step = get_real_step_or_rank(
|
|
82
|
+
step, Const.STEP) if step is not None else common_config.step
|
|
83
|
+
common_config.level = level if level else common_config.level
|
|
84
|
+
common_config.dump_path = dump_path if dump_path else common_config.dump_path
|
|
58
85
|
self.config = DebuggerConfig(common_config, task_config)
|
|
59
86
|
|
|
87
|
+
if _msprobe_c:
|
|
88
|
+
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
|
|
89
|
+
|
|
90
|
+
self.config.execution_mode = self._get_execution_mode()
|
|
91
|
+
if self._need_service():
|
|
92
|
+
self.service = Service(self.config)
|
|
93
|
+
|
|
60
94
|
Runtime.step_count = 0
|
|
61
95
|
Runtime.is_running = False
|
|
62
96
|
|
|
97
|
+
@staticmethod
|
|
98
|
+
def check_input_params(args):
|
|
99
|
+
if args.config_path is not None:
|
|
100
|
+
if not isinstance(args.config_path, str):
|
|
101
|
+
raise MsprobeException(
|
|
102
|
+
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
103
|
+
file_checker = FileChecker(
|
|
104
|
+
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
105
|
+
file_checker.common_check()
|
|
106
|
+
|
|
107
|
+
if args.task is not None and args.task not in Const.TASK_LIST:
|
|
108
|
+
raise MsprobeException(
|
|
109
|
+
MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
|
|
110
|
+
|
|
111
|
+
if args.dump_path is not None:
|
|
112
|
+
if not isinstance(args.dump_path, str):
|
|
113
|
+
raise MsprobeException(
|
|
114
|
+
MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
|
|
115
|
+
|
|
116
|
+
if args.level is not None and args.level not in Const.LEVEL_LIST:
|
|
117
|
+
raise MsprobeException(
|
|
118
|
+
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
119
|
+
|
|
63
120
|
@staticmethod
|
|
64
121
|
def _get_execution_mode():
|
|
65
122
|
jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL)
|
|
@@ -78,11 +135,23 @@ class PrecisionDebugger:
|
|
|
78
135
|
else:
|
|
79
136
|
return MsConst.PYNATIVE_MODE
|
|
80
137
|
|
|
138
|
+
@staticmethod
|
|
139
|
+
def _is_graph_dump(config):
|
|
140
|
+
if config.level != MsConst.KERNEL:
|
|
141
|
+
return False
|
|
142
|
+
if not config.list or len(config.list) > 1:
|
|
143
|
+
return True
|
|
144
|
+
if '-' in config.list[0] or '/' in config.list[0]:
|
|
145
|
+
return True
|
|
146
|
+
return False
|
|
147
|
+
|
|
81
148
|
@classmethod
|
|
82
149
|
def start(cls, model=None):
|
|
83
150
|
instance = cls._instance
|
|
84
151
|
if not instance:
|
|
85
152
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
153
|
+
if _msprobe_c:
|
|
154
|
+
_msprobe_c._PrecisionDebugger().start()
|
|
86
155
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
87
156
|
return
|
|
88
157
|
|
|
@@ -93,6 +162,7 @@ class PrecisionDebugger:
|
|
|
93
162
|
instance.service.start(model)
|
|
94
163
|
else:
|
|
95
164
|
if not instance.first_start:
|
|
165
|
+
api_register.api_set_ori_func()
|
|
96
166
|
handler = TaskHandlerFactory.create(instance.config)
|
|
97
167
|
handler.handle()
|
|
98
168
|
|
|
@@ -102,18 +172,15 @@ class PrecisionDebugger:
|
|
|
102
172
|
@classmethod
|
|
103
173
|
def forward_backward_dump_end(cls):
|
|
104
174
|
instance = cls._instance
|
|
105
|
-
|
|
106
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
107
|
-
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
108
|
-
return
|
|
109
|
-
if instance.service:
|
|
110
|
-
instance.service.forward_backward_dump_end()
|
|
175
|
+
instance.stop()
|
|
111
176
|
|
|
112
177
|
@classmethod
|
|
113
178
|
def stop(cls):
|
|
114
179
|
instance = cls._instance
|
|
115
180
|
if not instance:
|
|
116
181
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
182
|
+
if _msprobe_c:
|
|
183
|
+
_msprobe_c._PrecisionDebugger().stop()
|
|
117
184
|
if instance.task == Const.GRAD_PROBE:
|
|
118
185
|
instance.gm.stop()
|
|
119
186
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
@@ -127,6 +194,8 @@ class PrecisionDebugger:
|
|
|
127
194
|
instance = cls._instance
|
|
128
195
|
if not instance:
|
|
129
196
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
197
|
+
if _msprobe_c:
|
|
198
|
+
_msprobe_c._PrecisionDebugger().step()
|
|
130
199
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
131
200
|
return
|
|
132
201
|
if instance.service:
|
|
@@ -153,4 +222,4 @@ class PrecisionDebugger:
|
|
|
153
222
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
154
223
|
return False
|
|
155
224
|
else:
|
|
156
|
-
return instance.config.task != Const.FREE_BENCHMARK and instance.config
|
|
225
|
+
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,7 +12,6 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
16
|
from mindspore import Tensor, ops, mint
|
|
17
17
|
from mindspore.mint.nn import functional
|
|
@@ -20,8 +20,15 @@ from mindspore.communication import comm_func
|
|
|
20
20
|
|
|
21
21
|
from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
|
|
22
22
|
HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
|
|
23
|
-
|
|
23
|
+
HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
|
|
24
|
+
HOOKTorchDistributedOP, HOOKTorchNpuOP,
|
|
25
|
+
get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
|
|
24
26
|
from msprobe.core.common.utils import Const
|
|
27
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
28
|
+
|
|
29
|
+
if is_mindtorch():
|
|
30
|
+
import torch
|
|
31
|
+
import torch_npu
|
|
25
32
|
|
|
26
33
|
|
|
27
34
|
def stub_method(method):
|
|
@@ -40,6 +47,12 @@ class ApiRegistry:
|
|
|
40
47
|
self.distributed_ori_attr = {}
|
|
41
48
|
self.norm_inner_ops_ori_attr = {}
|
|
42
49
|
|
|
50
|
+
self.torch_ori_attr = {}
|
|
51
|
+
self.torch_tensor_ori_attr = {}
|
|
52
|
+
self.torch_functional_ori_attr = {}
|
|
53
|
+
self.torch_distributed_ori_attr = {}
|
|
54
|
+
self.torch_npu_ori_attr = {}
|
|
55
|
+
|
|
43
56
|
self.tensor_hook_attr = {}
|
|
44
57
|
self.stub_tensor_hook_attr = {}
|
|
45
58
|
self.functional_hook_attr = {}
|
|
@@ -48,6 +61,12 @@ class ApiRegistry:
|
|
|
48
61
|
self.distibuted_hook_attr = {}
|
|
49
62
|
self.norm_inner_ops_hook_attr = {}
|
|
50
63
|
|
|
64
|
+
self.torch_hook_attr = {}
|
|
65
|
+
self.torch_tensor_hook_attr = {}
|
|
66
|
+
self.torch_functional_hook_attr = {}
|
|
67
|
+
self.torch_distributed_hook_attr = {}
|
|
68
|
+
self.torch_npu_hook_attr = {}
|
|
69
|
+
|
|
51
70
|
self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
|
|
52
71
|
|
|
53
72
|
@staticmethod
|
|
@@ -82,22 +101,71 @@ class ApiRegistry:
|
|
|
82
101
|
self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
|
|
83
102
|
|
|
84
103
|
def api_set_hook_func(self):
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
104
|
+
if is_mindtorch():
|
|
105
|
+
self.set_api_attr(torch, self.torch_hook_attr)
|
|
106
|
+
self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
|
|
107
|
+
self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
|
|
108
|
+
self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
|
|
109
|
+
self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
|
|
110
|
+
else:
|
|
111
|
+
self.set_api_attr(Tensor, self.tensor_hook_attr)
|
|
112
|
+
self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
|
|
113
|
+
self.set_api_attr(ops, self.functional_hook_attr)
|
|
114
|
+
self.set_api_attr(mint, self.mint_ops_hook_attr)
|
|
115
|
+
self.set_api_attr(functional, self.mint_func_ops_hook_attr)
|
|
116
|
+
self.set_api_attr(comm_func, self.distibuted_hook_attr)
|
|
91
117
|
|
|
92
118
|
def api_set_ori_func(self):
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
119
|
+
if is_mindtorch():
|
|
120
|
+
self.set_api_attr(torch, self.torch_ori_attr)
|
|
121
|
+
self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
|
|
122
|
+
self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
|
|
123
|
+
self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
|
|
124
|
+
self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
|
|
125
|
+
else:
|
|
126
|
+
self.set_api_attr(Tensor, self.tensor_ori_attr)
|
|
127
|
+
self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
|
|
128
|
+
self.set_api_attr(ops, self.functional_ori_attr)
|
|
129
|
+
self.set_api_attr(mint, self.mint_ops_ori_attr)
|
|
130
|
+
self.set_api_attr(functional, self.mint_func_ops_ori_attr)
|
|
131
|
+
self.set_api_attr(comm_func, self.distributed_ori_attr)
|
|
99
132
|
|
|
100
133
|
def initialize_hook(self, hook):
|
|
134
|
+
setup_hooks(hook)
|
|
135
|
+
if is_mindtorch():
|
|
136
|
+
wrap_torch_api_name = get_wrap_torch_api_list()
|
|
137
|
+
self.store_ori_attr(torch,
|
|
138
|
+
wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
|
|
139
|
+
self.store_ori_attr(torch.Tensor,
|
|
140
|
+
wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
|
|
141
|
+
self.store_ori_attr(torch.nn.functional,
|
|
142
|
+
wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
|
|
143
|
+
self.store_ori_attr(torch.distributed,
|
|
144
|
+
wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
|
|
145
|
+
self.store_ori_attr(torch_npu,
|
|
146
|
+
wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
|
|
147
|
+
for attr_name in dir(HOOKTorchOP):
|
|
148
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
149
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
150
|
+
self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
|
|
151
|
+
for attr_name in dir(HOOKTorchTensor):
|
|
152
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
153
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
154
|
+
self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
|
|
155
|
+
for attr_name in dir(HOOKTorchFunctionalOP):
|
|
156
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
157
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
158
|
+
self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
|
|
159
|
+
for attr_name in dir(HOOKTorchDistributedOP):
|
|
160
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
161
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
162
|
+
self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
|
|
163
|
+
for attr_name in dir(HOOKTorchNpuOP):
|
|
164
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
165
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
166
|
+
self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
|
|
167
|
+
return
|
|
168
|
+
|
|
101
169
|
wrap_api_name = get_wrap_api_list()
|
|
102
170
|
self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
|
|
103
171
|
self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
|
|
@@ -106,7 +174,6 @@ class ApiRegistry:
|
|
|
106
174
|
self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
|
|
107
175
|
self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
|
|
108
176
|
self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
|
|
109
|
-
setup_hooks(hook)
|
|
110
177
|
for attr_name in dir(HOOKTensor):
|
|
111
178
|
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
112
179
|
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,45 +12,66 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
16
|
from collections import defaultdict
|
|
17
17
|
|
|
18
18
|
from mindspore import nn
|
|
19
19
|
|
|
20
|
-
from msprobe.
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
cell_count
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
20
|
+
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def add_cell_count(name):
|
|
24
|
+
HOOKCell.cell_count[name] += 1
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_cell_count(name):
|
|
28
|
+
return HOOKCell.cell_count[name]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __init__(self, build_hook) -> None:
|
|
32
|
+
super(HOOKCell, self).__init__()
|
|
33
|
+
self.changed_status = False
|
|
34
|
+
self.input_kwargs = {}
|
|
35
|
+
self.prefix = ""
|
|
36
|
+
if not HOOKCell.g_stop_hook:
|
|
37
|
+
HOOKCell.g_stop_hook = True
|
|
38
|
+
self.changed_status = True
|
|
39
|
+
if hasattr(self, "prefix_api_name"):
|
|
40
|
+
self.prefix = self.prefix_api_name
|
|
41
|
+
|
|
42
|
+
self.forward_data_collected = False
|
|
43
|
+
forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
|
|
44
|
+
self.register_forward_pre_hook(forward_pre_hook)
|
|
45
|
+
self.register_forward_hook(forward_hook)
|
|
46
|
+
register_backward_hook_functions["full"](self, backward_hook)
|
|
47
|
+
register_backward_hook_functions["pre"](self, backward_pre_hook)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# 重载call,加全局标志。
|
|
51
|
+
def __call__(self, *args, **kwargs):
|
|
52
|
+
try:
|
|
53
|
+
self.input_kwargs = kwargs
|
|
54
|
+
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
55
|
+
except Exception as e:
|
|
56
|
+
raise e
|
|
57
|
+
finally:
|
|
58
|
+
if self.changed_status:
|
|
59
|
+
self.changed_status = False
|
|
60
|
+
HOOKCell.g_stop_hook = False
|
|
61
|
+
return out
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
hook_cell_dict = {
|
|
65
|
+
"cell_count": defaultdict(int),
|
|
66
|
+
"g_stop_hook": False,
|
|
67
|
+
"add_cell_count": staticmethod(add_cell_count),
|
|
68
|
+
"get_cell_count": staticmethod(get_cell_count),
|
|
69
|
+
"__init__": __init__,
|
|
70
|
+
"__call__": __call__
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
if is_mindtorch():
|
|
74
|
+
import torch
|
|
75
|
+
HOOKCell = type("HOOKCell", (torch.nn.Module,), hook_cell_dict)
|
|
76
|
+
else:
|
|
77
|
+
HOOKCell = type("HOOKCell", (nn.Cell,), hook_cell_dict)
|
|
@@ -135,6 +135,34 @@ class PrimitiveHookService:
|
|
|
135
135
|
return tuple(hooked_outputs)
|
|
136
136
|
return out
|
|
137
137
|
|
|
138
|
+
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
|
|
139
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
140
|
+
try:
|
|
141
|
+
self.service_instance.data_collector.forward_input_data_collect(
|
|
142
|
+
primitive_name,
|
|
143
|
+
primitive_instance,
|
|
144
|
+
os.getpid(),
|
|
145
|
+
module_input_output
|
|
146
|
+
)
|
|
147
|
+
except Exception as exception:
|
|
148
|
+
logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
|
|
149
|
+
f"primitive_name: {primitive_name}")
|
|
150
|
+
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
151
|
+
|
|
152
|
+
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
    """Hand a primitive's forward output to the data collector.

    Args:
        primitive_name: Name under which the output is collected.
        primitive_instance: The primitive object that was called.
        args: Positional inputs of the call.
        kwargs: Keyword inputs of the call.
        output: Value returned by the primitive call.

    Raises:
        DumpException: With FORWARD_DATA_COLLECTION_ERROR when collection
            fails for any reason; the original exception is chained.
    """
    # Inputs and output travel together in one record.
    forward_io = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
    try:
        # The collector lookup stays inside the try so that a failure to
        # resolve it is reported the same way as a collection failure.
        collector = self.service_instance.data_collector
        collector.forward_output_data_collect(primitive_name, primitive_instance, os.getpid(), forward_io)
    except Exception as err:
        message = (f"This is a primitive op dump error during forward output data collection: {err}, "
                   f"primitive_name: {primitive_name}")
        logger.error(message)
        raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from err
|
|
165
|
+
|
|
138
166
|
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
139
167
|
"""
|
|
140
168
|
包装后的 primitive 调用函数,添加输入和输出的 hook。
|
|
@@ -163,27 +191,17 @@ class PrimitiveHookService:
|
|
|
163
191
|
f"primitive_name: {primitive_name}")
|
|
164
192
|
raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
|
|
165
193
|
|
|
194
|
+
forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
|
|
195
|
+
self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
196
|
+
|
|
197
|
+
pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
|
|
166
198
|
try:
|
|
167
199
|
out = origin_func(*hooked_inputs, **kwargs)
|
|
168
200
|
except Exception as exception:
|
|
169
201
|
logger.error(f"This is a primitive op dump error during function call: {exception}, "
|
|
170
202
|
f"primitive_name: {primitive_name}")
|
|
171
203
|
raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
|
|
172
|
-
|
|
173
|
-
forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
|
|
174
|
-
self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
175
|
-
if self.service_instance.data_collector:
|
|
176
|
-
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
|
|
177
|
-
try:
|
|
178
|
-
self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
|
|
179
|
-
os.getpid(), module_input_output)
|
|
180
|
-
except Exception as exception:
|
|
181
|
-
logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
|
|
182
|
-
f"primitive_name: {primitive_name}")
|
|
183
|
-
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
184
|
-
|
|
185
|
-
if self.service_instance.data_collector.if_return_forward_new_output():
|
|
186
|
-
out = self.service_instance.data_collector.get_forward_new_output()
|
|
204
|
+
post_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs, out)
|
|
187
205
|
|
|
188
206
|
try:
|
|
189
207
|
out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
# List of ops that register hooks
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
|
|
19
19
|
ops:
|
|
20
20
|
- adaptive_avg_pool1d
|
|
21
21
|
- adaptive_avg_pool2d
|
|
@@ -85,6 +85,7 @@ ops:
|
|
|
85
85
|
- relu6
|
|
86
86
|
- celu
|
|
87
87
|
- rrelu
|
|
88
|
+
- rms_norm
|
|
88
89
|
- selu
|
|
89
90
|
- sigmoid
|
|
90
91
|
- silu
|
|
@@ -553,6 +554,7 @@ tensor:
|
|
|
553
554
|
- acos
|
|
554
555
|
- acosh
|
|
555
556
|
- add
|
|
557
|
+
- add_
|
|
556
558
|
- addbmm
|
|
557
559
|
- addcdiv
|
|
558
560
|
- addcmul
|
|
@@ -607,6 +609,7 @@ tensor:
|
|
|
607
609
|
- diff
|
|
608
610
|
- digamma
|
|
609
611
|
- div
|
|
612
|
+
- div_
|
|
610
613
|
- divide
|
|
611
614
|
- equal
|
|
612
615
|
- erf
|
|
@@ -739,6 +742,8 @@ tensor:
|
|
|
739
742
|
- square
|
|
740
743
|
- squeeze
|
|
741
744
|
- std
|
|
745
|
+
- sub
|
|
746
|
+
- sub_
|
|
742
747
|
- subtract
|
|
743
748
|
- subtract
|
|
744
749
|
- svd
|
|
@@ -983,6 +988,7 @@ mint.nn.functional:
|
|
|
983
988
|
- one_hot_ext
|
|
984
989
|
- pad
|
|
985
990
|
- relu
|
|
991
|
+
- relu_
|
|
986
992
|
- sigmoid
|
|
987
993
|
- silu
|
|
988
994
|
- softmax
|
|
@@ -1017,3 +1023,7 @@ communication.comm_func:
|
|
|
1017
1023
|
- broadcast
|
|
1018
1024
|
- gather_into_tensor
|
|
1019
1025
|
- scatter_tensor
|
|
1026
|
+
- send
|
|
1027
|
+
- recv
|
|
1028
|
+
- isend
|
|
1029
|
+
- irecv
|