mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
- msprobe/README.md +2 -2
- msprobe/core/common/const.py +34 -9
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +14 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -7
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/utils.py +10 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +92 -8
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
- msprobe/core/data_dump/json_writer.py +26 -8
- msprobe/docs/01.installation.md +25 -0
- msprobe/docs/02.config_introduction.md +14 -12
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +34 -15
- msprobe/docs/06.data_dump_MindSpore.md +45 -22
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
- msprobe/docs/19.monitor.md +257 -260
- msprobe/docs/21.visualization_PyTorch.md +10 -0
- msprobe/docs/22.visualization_MindSpore.md +11 -0
- msprobe/docs/27.dump_json_instruction.md +24 -20
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/utils.py +20 -2
- msprobe/mindspore/debugger/debugger_config.py +25 -2
- msprobe/mindspore/debugger/precision_debugger.py +25 -6
- msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/service.py +95 -21
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +71 -0
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +14 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +10 -12
- msprobe/pytorch/monitor/module_hook.py +123 -104
- msprobe/pytorch/monitor/module_metric.py +6 -6
- msprobe/pytorch/monitor/optimizer_collect.py +45 -63
- msprobe/pytorch/monitor/utils.py +8 -43
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +103 -24
- msprobe/visualization/builder/graph_builder.py +31 -5
- msprobe/visualization/builder/msprobe_adapter.py +7 -5
- msprobe/visualization/graph/base_node.py +3 -2
- msprobe/visualization/graph/distributed_analyzer.py +80 -3
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +3 -4
- msprobe/visualization/utils.py +10 -2
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import gc
|
|
18
|
+
import sys
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
import mindspore
|
|
21
|
+
from msprobe.mindspore.common.log import logger
|
|
22
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
23
|
+
import torch as mindtorch
|
|
24
|
+
from torch import Tensor as mindtorch_tensor
|
|
25
|
+
import torch.nn.functional as mindtorch_func
|
|
26
|
+
import torch.distributed as mindtorch_dist
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Module-level switch: True while a separate PyTorch (``torch``) and MindTorch
# (``mindtorch``) pair is considered usable; cleared by set_pt_mt_env_invalid().
is_valid_pt_mt_env = True
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def is_mindtorch():
    """Report whether the importable ``torch`` is MindTorch.

    A tensor created through ``torch`` is checked against ``mindspore.Tensor``:
    if it is a MindSpore tensor, the ``torch`` on the import path is MindTorch.

    Returns:
        bool: True when ``torch`` resolves to MindTorch; False otherwise,
        including when ``torch`` or ``mindspore`` cannot be imported.
    """
    try:
        import torch as probe_torch
        from mindspore import Tensor as MsTensor
    except ImportError:
        return False
    probe = probe_torch.tensor(0.0)
    return isinstance(probe, MsTensor)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def remove_torch_related_paths():
    """Remove the directory that supplies the MindTorch ``torch`` package from ``sys.path``.

    Does nothing unless the currently importable ``torch`` is MindTorch (see
    ``is_mindtorch``). Every ``sys.path`` entry whose resolved location equals
    the parent directory of the imported ``torch`` package is removed in place,
    so a subsequent ``import torch`` searches the remaining entries.
    """
    if not is_mindtorch():
        return
    try:
        import torch as remove_torch
    except ImportError:
        return
    torch_file = getattr(remove_torch, "__file__", None)
    if not torch_file:
        # e.g. a namespace package without a file location: nothing to strip.
        return

    parent_dir = Path(os.path.dirname(torch_file)).resolve().parent

    def _resolves_to_parent(entry):
        # Best-effort: an entry that cannot be resolved is simply kept.
        try:
            return Path(entry).resolve() == parent_dir
        except (OSError, RuntimeError, TypeError, ValueError) as error:
            logger.debug(f"Failed to resolve path {entry}: {error}")
            return False

    # Compare resolved forms on both sides: sys.path entries may be relative,
    # symlinked or carry trailing separators, so a raw string match can miss
    # them. All duplicates are dropped, not only the first occurrence.
    sys.path[:] = [entry for entry in sys.path if not _resolves_to_parent(entry)]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def clear_torch_from_sys_modules():
    """Evict ``torch`` and every ``torch.*`` submodule from ``sys.modules``.

    After this, the next ``import torch`` searches ``sys.path`` again instead
    of returning the cached module object.
    """
    stale_names = [name for name in sys.modules if name == "torch" or name.startswith("torch.")]
    for name in stale_names:
        del sys.modules[name]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def set_pt_mt_env_invalid():
    """Clear the module-level ``is_valid_pt_mt_env`` flag.

    Importers that check this flag (e.g. type_mapping) then fall back to a
    plain ``import torch`` instead of the separate torch/mindtorch pair.
    """
    global is_valid_pt_mt_env
    is_valid_pt_mt_env = False
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def delete_torch_paths():
    """Strip MindTorch from ``sys.path``/``sys.modules`` until ``torch`` resolves elsewhere.

    Each pass removes the directory that supplies the current MindTorch
    ``torch`` package and evicts the cached ``torch`` modules, then re-checks.
    At most ``MsCompareConst.MAX_RECURSION_DEPTH`` passes are made.

    Raises:
        Exception: if ``torch`` still resolves to MindTorch after all passes.
    """
    if not is_mindtorch():
        set_pt_mt_env_invalid()

    clear_torch_from_sys_modules()

    max_depth = MsCompareConst.MAX_RECURSION_DEPTH
    for _ in range(max_depth):
        if not is_mindtorch():
            break
        remove_torch_related_paths()
        clear_torch_from_sys_modules()
    else:
        # All passes were used. The final pass may still have succeeded, so
        # only fail when ``torch`` is still MindTorch.
        if is_mindtorch():
            raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
                            f"the PYTHONPATH environment variable depth does not exceed {max_depth}.")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Resolve real PyTorch as ``torch`` alongside MindTorch (bound above as
# ``mindtorch``). When the first ``torch`` found is MindTorch, its location is
# temporarily stripped from ``sys.path`` so another ``torch`` can be imported.
# The original search path is always restored, even if that step fails.
if not is_mindtorch():
    set_pt_mt_env_invalid()
else:
    initial_sys_path = sys.path.copy()
    try:
        delete_torch_paths()
        gc.collect()
        import torch

        # Defensive re-check: if ``torch`` is still MindTorch, there is no
        # usable separate PyTorch installation.
        if is_mindtorch():
            set_pt_mt_env_invalid()
    finally:
        sys.path = initial_sys_path
|
|
128
|
+
|
|
129
|
+
|
|
@@ -15,10 +15,18 @@
|
|
|
15
15
|
|
|
16
16
|
import mindspore
|
|
17
17
|
import numpy as np
|
|
18
|
-
import torch
|
|
19
18
|
from mindspore._c_expression import typing
|
|
20
19
|
from mindspore.common import dtype as mstype
|
|
21
20
|
|
|
21
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
22
|
+
|
|
23
|
+
if torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
24
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
26
|
+
else:
|
|
27
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
28
|
+
import torch
|
|
29
|
+
|
|
22
30
|
# Canonical dtype name strings; used as keys of the dtype lookup tables in this module.
INT8 = "Int8"
UINT8 = "UInt8"
INT16 = "Int16"
|
|
@@ -82,6 +90,21 @@ dtype_str_to_torch_dtype = {
|
|
|
82
90
|
}
|
|
83
91
|
# Reverse lookup: torch dtype -> canonical dtype name string.
torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}

# Canonical dtype name string -> dtype of the ``mindtorch`` module imported
# from torch_mindtorch_importer.
dtype_str_to_mindtorch_dtype = {
    INT8: mindtorch.int8,
    UINT8: mindtorch.uint8,
    INT16: mindtorch.int16,
    INT32: mindtorch.int32,
    INT64: mindtorch.int64,
    FLOAT16: mindtorch.float16,
    FLOAT32: mindtorch.float32,
    FLOAT64: mindtorch.float64,
    BOOL: mindtorch.bool,
    BFLOAT16: mindtorch.bfloat16,
}
# Reverse lookup: mindtorch dtype -> canonical dtype name string.
mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()}

# Type-name strings.
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
|
|
@@ -82,10 +82,12 @@ class GlobalContext:
|
|
|
82
82
|
def __init__(self):
    # Initial values; init() overwrites all three.
    self.is_constructed = True
    self.dump_data_dir = ""
    self.framework = Const.MS_FRAMEWORK
|
|
85
86
|
|
|
86
|
-
def init(self, is_constructed, dump_data_dir, framework=None):
    """Reset the stored context values.

    Args:
        is_constructed: value stored as ``is_constructed``.
        dump_data_dir: value stored as ``dump_data_dir``.
        framework: framework identifier. When omitted (or None), it falls back
            to ``Const.MS_FRAMEWORK``, the same value ``__init__`` uses. This
            keeps callers that pass only two arguments (the previous
            signature) working.
    """
    self.is_constructed = is_constructed
    self.dump_data_dir = dump_data_dir
    # Resolved lazily so the default does not bind Const at definition time.
    self.framework = Const.MS_FRAMEWORK if framework is None else framework
|
|
89
91
|
|
|
90
92
|
def get_dump_data_dir(self):
    """Return the stored ``dump_data_dir``."""
    return self.dump_data_dir
|
|
@@ -93,5 +95,8 @@ class GlobalContext:
|
|
|
93
95
|
def get_is_constructed(self):
    """Return the stored ``is_constructed`` flag."""
    return self.is_constructed
|
|
95
97
|
|
|
98
|
+
def get_framework(self):
    """Return the stored framework identifier."""
    return self.framework
|
|
100
|
+
|
|
96
101
|
|
|
97
102
|
# Shared module-level GlobalContext instance.
global_context = GlobalContext()
|
|
@@ -151,11 +151,10 @@ def is_mindtorch():
|
|
|
151
151
|
mindtorch_check_result = False
|
|
152
152
|
try:
|
|
153
153
|
import torch
|
|
154
|
-
from mindspore._c_expression import Tensor
|
|
155
154
|
except ImportError:
|
|
156
155
|
return mindtorch_check_result
|
|
157
156
|
tensor = torch.tensor(0.0)
|
|
158
|
-
if isinstance(tensor, Tensor):
|
|
157
|
+
if isinstance(tensor, ms.Tensor):
|
|
159
158
|
mindtorch_check_result = True
|
|
160
159
|
return mindtorch_check_result
|
|
161
160
|
|
|
@@ -179,3 +178,22 @@ def set_register_backward_hook_functions():
|
|
|
179
178
|
else:
|
|
180
179
|
register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook
|
|
181
180
|
register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def check_save_param(variable, name, save_backward):
    """Validate the arguments of ``PrecisionDebugger.save``.

    Each failed check logs a warning and raises ``ValueError``. The caller
    catches it so the current save call is skipped.

    Args:
        variable: object to save; must be a list, dict, ``ms.Tensor``, int,
            float or str.
        name: name for the saved data; must be a str.
        save_backward: flag forwarded to the save service; must be a bool.

    Raises:
        ValueError: if any argument has an unsupported type.
    """
    if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)):
        logger.warning("PrecisionDebugger.save variable type not valid, "
                       "should be one of list, dict, ms.Tensor, int, float or string. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save variable type not valid.")
    if not isinstance(name, str):
        logger.warning("PrecisionDebugger.save name not valid, "
                       "should be string. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save name not valid.")
    if not isinstance(save_backward, bool):
        logger.warning("PrecisionDebugger.save_backward value not valid, "
                       "should be bool. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save_backward value not valid.")
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,9 +16,11 @@
|
|
|
16
16
|
import os
|
|
17
17
|
|
|
18
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
19
20
|
from msprobe.core.common.file_utils import create_directory
|
|
20
21
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
21
22
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class DebuggerConfig:
|
|
@@ -50,7 +52,7 @@ class DebuggerConfig:
|
|
|
50
52
|
if not task_config.handler_type else task_config.handler_type)
|
|
51
53
|
self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
|
|
52
54
|
if self.handler_type == FreeBenchmarkConst.FIX and \
|
|
53
|
-
|
|
55
|
+
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
54
56
|
raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
55
57
|
f"but got {self.pert_type}.")
|
|
56
58
|
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
@@ -72,4 +74,25 @@ class DebuggerConfig:
|
|
|
72
74
|
self.check_mode = "all"
|
|
73
75
|
if not isinstance(self.async_dump, bool):
|
|
74
76
|
raise Exception("The parameters async_dump should be bool.")
|
|
77
|
+
if self.async_dump and self.task == Const.TENSOR and not self.list:
|
|
78
|
+
raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.")
|
|
79
|
+
if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
80
|
+
logger.warning_on_rank_0(
|
|
81
|
+
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
82
|
+
f"If not, the default level is {Const.LEVEL_MIX}."
|
|
83
|
+
)
|
|
84
|
+
self.level_ori = Const.LEVEL_MIX
|
|
75
85
|
return True
|
|
86
|
+
|
|
87
|
+
def check_config_with_l2(self):
    """Validate the configuration constraints that apply at level L2.

    This is a no-op unless ``self.level_ori`` is ``Const.LEVEL_L2``. At L2:
    - the task must be tensor;
    - ``scope`` must be empty;
    - ``list`` must contain exactly one API name.

    Raises:
        MsprobeException: with ``INVALID_PARAM_ERROR`` when a constraint is violated.
    """
    if self.level_ori != Const.LEVEL_L2:
        return
    if self.task != Const.TENSOR:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "When level is set to L2, the task must be set to tensor.")
    if self.scope:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "When level is set to L2, the scope cannot be configured.")
    # ``not self.list`` is checked first so a None list never reaches len().
    if not self.list or len(self.list) != 1:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "When level is set to L2, the list must be configured as a list with one api name.")
|
|
@@ -25,7 +25,7 @@ from msprobe.core.common.file_utils import FileChecker
|
|
|
25
25
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
26
26
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
27
27
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
28
|
-
from msprobe.mindspore.common.utils import set_register_backward_hook_functions
|
|
28
|
+
from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
|
|
29
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
30
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
31
31
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
@@ -89,6 +89,7 @@ class PrecisionDebugger:
|
|
|
89
89
|
|
|
90
90
|
self.config.execution_mode = self._get_execution_mode()
|
|
91
91
|
if self._need_service():
|
|
92
|
+
self.config.check_config_with_l2()
|
|
92
93
|
self.service = Service(self.config)
|
|
93
94
|
|
|
94
95
|
Runtime.step_count = 0
|
|
@@ -139,11 +140,11 @@ class PrecisionDebugger:
|
|
|
139
140
|
def _is_graph_dump(config):
    """Decide whether the configured dump is a graph (kernel-level) dump.

    Returns False unless ``config.level`` is ``MsConst.KERNEL``. At kernel
    level, it returns True in any of these cases:
    - ``config.list`` is empty;
    - any entry starts with ``"name-regex"``;
    - no entry contains a ``"."``.
    """
    if config.level != MsConst.KERNEL:
        return False
    api_list = config.list
    if not api_list:
        return True
    has_regex_entry = any(entry.startswith("name-regex") for entry in api_list)
    has_no_dotted_entry = all("." not in entry for entry in api_list)
    return has_regex_entry or has_no_dotted_entry
|
|
147
148
|
|
|
148
149
|
@classmethod
|
|
149
150
|
def start(cls, model=None):
|
|
@@ -214,6 +215,24 @@ class PrecisionDebugger:
|
|
|
214
215
|
return
|
|
215
216
|
instance.gm.monitor(opt)
|
|
216
217
|
|
|
218
|
+
@classmethod
def save(cls, variable, name, save_backward=True):
    """Save ``variable`` under ``name`` through the dump service.

    Raises an Exception if no PrecisionDebugger instance exists. The call
    silently does nothing in these cases:
    - the task is not tensor/statistics;
    - the level is not debug;
    - ``check_save_param`` rejects the arguments;
    - no service is needed for the current execution mode.
    The service is created lazily on first use.
    """
    instance = cls._instance
    if not instance:
        raise Exception(MsgConst.NOT_CREATED_INSTANCE)
    task_supported = instance.task in [Const.TENSOR, Const.STATISTICS]
    if not task_supported or instance.config.level_ori != Const.LEVEL_DEBUG:
        return
    try:
        check_save_param(variable, name, save_backward)
    except ValueError:
        return

    instance.config.execution_mode = cls._get_execution_mode()
    if not cls._need_service():
        return
    if not instance.service:
        instance.service = Service(instance.config)
    instance.service.save(variable, name, save_backward)
|
|
235
|
+
|
|
217
236
|
@classmethod
|
|
218
237
|
def _need_service(cls):
|
|
219
238
|
instance = cls._instance
|
|
@@ -222,4 +241,4 @@ class PrecisionDebugger:
|
|
|
222
241
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
223
242
|
return False
|
|
224
243
|
else:
|
|
225
|
-
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
244
|
+
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
@@ -106,6 +106,7 @@ class ApiRegistry:
|
|
|
106
106
|
self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
|
|
107
107
|
self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
|
|
108
108
|
self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
|
|
109
|
+
self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
|
|
109
110
|
self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
|
|
110
111
|
else:
|
|
111
112
|
self.set_api_attr(Tensor, self.tensor_hook_attr)
|
|
@@ -121,6 +122,7 @@ class ApiRegistry:
|
|
|
121
122
|
self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
|
|
122
123
|
self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
|
|
123
124
|
self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
|
|
125
|
+
self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
|
|
124
126
|
self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
|
|
125
127
|
else:
|
|
126
128
|
self.set_api_attr(Tensor, self.tensor_ori_attr)
|
|
@@ -16,14 +16,15 @@
|
|
|
16
16
|
import os
|
|
17
17
|
from collections import defaultdict
|
|
18
18
|
|
|
19
|
-
from mindspore import Tensor
|
|
20
19
|
from mindspore._c_expression import PyNativeExecutor_
|
|
21
|
-
|
|
20
|
+
try:
|
|
21
|
+
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
22
|
+
except ImportError:
|
|
23
|
+
from mindspore.common.api import _JitExecutor as _MindsporeFunctionExecutor
|
|
22
24
|
|
|
23
25
|
from msprobe.core.common.log import logger
|
|
24
|
-
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
25
26
|
from msprobe.core.common.const import Const
|
|
26
|
-
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
|
|
27
|
+
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
27
28
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
28
29
|
|
|
29
30
|
|
|
@@ -40,8 +41,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
40
41
|
if JitDump.need_dump():
|
|
41
42
|
if is_forward:
|
|
42
43
|
JitDump.jit_count[result] += 1
|
|
43
|
-
name_template = Const.JIT + Const.SEP + result + Const.SEP +
|
|
44
|
-
|
|
44
|
+
name_template = (Const.JIT + Const.SEP + result + Const.SEP +
|
|
45
|
+
str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD)
|
|
45
46
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
46
47
|
module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
|
|
47
48
|
JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
|