mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
- msprobe/README.md +2 -2
- msprobe/core/common/const.py +34 -9
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +14 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -7
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/utils.py +10 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +92 -8
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
- msprobe/core/data_dump/json_writer.py +26 -8
- msprobe/docs/01.installation.md +25 -0
- msprobe/docs/02.config_introduction.md +14 -12
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +34 -15
- msprobe/docs/06.data_dump_MindSpore.md +45 -22
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
- msprobe/docs/19.monitor.md +257 -260
- msprobe/docs/21.visualization_PyTorch.md +10 -0
- msprobe/docs/22.visualization_MindSpore.md +11 -0
- msprobe/docs/27.dump_json_instruction.md +24 -20
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/utils.py +20 -2
- msprobe/mindspore/debugger/debugger_config.py +25 -2
- msprobe/mindspore/debugger/precision_debugger.py +25 -6
- msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/service.py +95 -21
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +71 -0
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +14 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +10 -12
- msprobe/pytorch/monitor/module_hook.py +123 -104
- msprobe/pytorch/monitor/module_metric.py +6 -6
- msprobe/pytorch/monitor/optimizer_collect.py +45 -63
- msprobe/pytorch/monitor/utils.py +8 -43
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +103 -24
- msprobe/visualization/builder/graph_builder.py +31 -5
- msprobe/visualization/builder/msprobe_adapter.py +7 -5
- msprobe/visualization/graph/base_node.py +3 -2
- msprobe/visualization/graph/distributed_analyzer.py +80 -3
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +3 -4
- msprobe/visualization/utils.py +10 -2
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/module_spec_verifier.py ADDED (new file, +94 lines)

```python
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import abc
from mindspore import Tensor

from msprobe.core.common.log import logger


# Registry of all validator implementation classes
config_validator_registry = {}


def register_config_validator(cls):
    """Decorator that registers a ConfigValidator implementation class."""
    config_validator_registry[cls.__name__] = cls
    return cls


class ConfigValidator(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        pass

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        pass


@register_config_validator
class TensorValidator(ConfigValidator):
    def check_pattern_match(self, config_spec: str):
        pattern = re.compile(r"tensor")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        if not isinstance(actual_data, Tensor):
            raise ValueError(
                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")


@register_config_validator
class TupleValidator(ConfigValidator):
    def check_pattern_match(self, config_spec: str):
        pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        length, index = pattern_match.groups()
        if index is None:
            index = 0
        length, index = int(length), int(index)

        if not (0 <= index < length):
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'."
                f"y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != length:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
                f"actual is {len(actual_data)} please check.")
        return index


def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    focused_col = None
    for _, validator_cls in config_validator_registry.items():
        config_validator = validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if pattern_match:
            try:
                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
            except ValueError as e:
                logger.warning(f"config spec validate failed: {str(e)}")
            return focused_col
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
    return focused_col
```
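For orientation, here is a minimal sketch of how `validate_config_spec` resolves a spec string against a module's actual data; the module name and tensors below are made-up illustration values, not part of the diff:

```python
from mindspore import Tensor
from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec

out = (Tensor([1.0, 2.0]), Tensor([3.0, 4.0]))   # pretend forward output: a 2-tuple
col = validate_config_spec("tuple[2]:1", out, "net.layer1", "output")
print(col)  # 1 -> monitoring focuses on the second tuple element

# A plain "tensor" spec matches TensorValidator and returns None (nothing to index).
validate_config_spec("tensor", out[0], "net.layer1", "input")
```

A spec that matches no registered validator only logs a warning and returns `None`, so a typo in `config.json` degrades to "nothing monitored" rather than a crash.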
msprobe/mindspore/monitor/utils.py ADDED (new file, +267 lines)

```python
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mindspore import dtype as mstype, Tensor

from msprobe.mindspore.monitor.features import FUNC_MAP
from msprobe.core.common.const import MonitorConst
from msprobe.core.common.utils import is_int
from msprobe.core.common.log import logger


def get_single_metrics(op_list, tag, tensor, output=None):
    if output is None:
        output = {}
    if tag not in output:
        output[tag] = {}
    for op in op_list:
        func = FUNC_MAP.get(op)
        statistic = func(tensor)
        if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
            statistic = float(statistic)
            statistic = Tensor(statistic)
        output[tag][op] = statistic.astype(mstype.float32)


def get_metrics(op_list, tag2tensor, eps, output=None):
    if output is None:
        output = {}
    for tag, tensor in tag2tensor.items():
        if tag not in output:
            output[tag] = {}
        get_single_metrics(op_list, tag, tensor, output)
    return output


def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
    if rank is None:
        return f"{module_or_param_name}/{tag}"
    else:
        return f"{module_or_param_name}/rank{rank}/{tag}"


def step_accumulates_one(context, micro_batch_number):
    """
    :param context: ModuleHookContext
    :param micro_batch_number: mbs of training model.
    :return:
    """
    context.micro_step += 1
    if context.micro_step == micro_batch_number:
        context.micro_step = 0
        context.step += 1


def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
    """
    If current step less than start_step or not reach step_interval, skip current step.
    :param step: current training step, int
    :param start_step: int
    :param step_interval: int
    :return: whether skip or not, bool
    """
    return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times


def validate_ops(ops):
    if not isinstance(ops, list):
        raise TypeError("ops should be a list")
    valid_ops = []
    for op in ops:
        if op not in MonitorConst.OP_LIST:
            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
            continue
        valid_ops.append(op)
    if not valid_ops:
        default_op = MonitorConst.OP_LIST[0]
        valid_ops.append(default_op)
        logger.info(f"There is no valid ops, default op {default_op} is used")
    return valid_ops


def validate_ranks(ranks):
    if not isinstance(ranks, list):
        raise TypeError("module_ranks should be a list")
    for rank in ranks:
        if not isinstance(rank, str):
            raise TypeError(f"element in module_ranks should be a str, get {type(rank)}")


def validate_targets(targets):
    if not isinstance(targets, dict):
        raise TypeError('targets in config.json should be a dict')
    for module_name, field in targets.items():
        if not isinstance(module_name, str):
            raise TypeError('key of targets should be module_name[str] in config.json')
        if not isinstance(field, dict):
            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')


def validate_print_struct(print_struct):
    if not isinstance(print_struct, bool):
        raise TypeError("print_struct should be a bool")


def validate_ur_distribution(ur_distribution):
    if not isinstance(ur_distribution, bool):
        raise TypeError('ur_distribution should be a bool')


def validate_xy_distribution(xy_distribution):
    if not isinstance(xy_distribution, bool):
        raise TypeError('xy_distribution should be a bool')


def validate_wg_distribution(wg_distribution):
    if not isinstance(wg_distribution, bool):
        raise TypeError('wg_distribution should be a bool')


def validate_mg_distribution(mg_distribution):
    if not isinstance(mg_distribution, bool):
        raise TypeError('mg_distribution should be a bool')


def validate_param_distribution(param_distribution):
    if not isinstance(param_distribution, bool):
        raise TypeError('param_distribution should be a bool')


def validate_cc_distribution(cc_distribution):
    if not isinstance(cc_distribution, dict):
        raise TypeError('cc_distribution should be a dictionary')
    expected_keys = {
        'enable': bool,
        'cc_codeline': list,
        'cc_pre_hook': bool,
        'cc_log_only': bool
    }
    for key, value in cc_distribution.items():
        if key in expected_keys:
            if not isinstance(value, expected_keys[key]):
                raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}')
        else:
            raise TypeError(f'{key} of cc_distribution is not supported.')


def validate_alert(alert):
    if not isinstance(alert, dict):
        raise TypeError('alert should be a dictionary')
    rules = alert.get('rules')
    if rules and isinstance(rules, list):
        for rule in rules:
            rule_name = rule.get("rule_name")
            if rule_name and rule_name not in MonitorConst.RULE_NAME:
                raise TypeError(f"{rule_name} is not supported")
            args = rule.get("args")
            if args and isinstance(args, dict):
                threshold = args.get("threshold")
                if not isinstance(threshold, float) or threshold < 0:
                    raise TypeError('threshold must be float and not less than 0')
    dump = alert.get('dump')
    if dump and not isinstance(dump, bool):
        raise TypeError('dump must be bool.')


def validate_step_count_per_record(step_count_per_record):
    if not is_int(step_count_per_record):
        raise TypeError('step_count_per_record must be int.')
    if step_count_per_record < 1:
        raise ValueError("step_count_per_record must greater than 0")
    if step_count_per_record > 1e6:
        raise ValueError("step_count_per_record must smaller than 1e6")


def validate_start_step(start_step):
    if not is_int(start_step):
        raise TypeError('start_step must be int.')
    if start_step < 0:
        raise ValueError("start_step must greater than 0")
    if start_step > 1e8:
        raise ValueError("start_step must smaller than 1e8")


def validate_step_interval(step_interval):
    if not is_int(step_interval):
        raise TypeError('step_interval must be int.')
    if step_interval < 1:
        raise ValueError("step_interval must greater than 1")
    if step_interval > 1e8:
        raise ValueError("step_interval must smaller than 1e8")


def validate_collect_times(collect_times):
    if not is_int(collect_times):
        raise TypeError('collect_times must be int.')
    if collect_times < 1:
        raise ValueError("collect_times must greater than 1")


def validate_config(config):
    config['ops'] = validate_ops(config.get('ops', []))

    eps = config.get('eps', 1e-8)
    if not isinstance(eps, float):
        raise TypeError("eps should be a float")

    ranks = config.get("module_ranks", [])
    validate_ranks(ranks)

    targets = config.get("targets", {})
    validate_targets(targets)

    print_struct = config.get('print_struct', False)
    validate_print_struct(print_struct)

    ur_distribution = config.get('ur_distribution', False)
    validate_ur_distribution(ur_distribution)

    xy_distribution = config.get('xy_distribution', False)
    validate_xy_distribution(xy_distribution)

    wg_distribution = config.get('wg_distribution', False)
    validate_wg_distribution(wg_distribution)

    mg_distribution = config.get('mg_distribution', False)
    validate_mg_distribution(mg_distribution)

    param_distribution = config.get('param_distribution', False)
    validate_param_distribution(param_distribution)

    cc_distribution = config.get('cc_distribution', {})
    validate_cc_distribution(cc_distribution)

    alert = config.get('alert', {})
    validate_alert(alert)

    step_count_per_record = config.get('step_count_per_record', 1)
    validate_step_count_per_record(step_count_per_record)

    start_step = config.get('start_step', 0)
    validate_start_step(start_step)

    step_interval = config.get('step_interval', 1)
    validate_step_interval(step_interval)

    collect_times = config.get('collect_times', 1e8)
    validate_collect_times(collect_times)

    if not targets:
        if xy_distribution:
            config["all_xy"] = True
        config["targets"] = {"": {}}
        config["is_select"] = False
    else:
        config["is_select"] = True
```
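A quick sketch of how these helpers fit together: validate a monitor config, then compute the chosen statistics for one tagged tensor. The op names `"min"`/`"max"`/`"norm"` are assumptions about `MonitorConst.OP_LIST` and `FUNC_MAP` (typical values; check msprobe/core/common/const.py for the real list), and the target/tag names are invented:

```python
from mindspore import Tensor
from msprobe.mindspore.monitor.utils import get_metrics, validate_config

config = {
    "ops": ["min", "max", "norm"],                          # assumed to be in MonitorConst.OP_LIST
    "targets": {"net.layer1": {"input": "tuple[2]:0"}},
}
validate_config(config)   # normalizes in place; unsupported ops are dropped, not fatal

tag2tensor = {"net.layer1/input": Tensor([1.0, -2.0, 3.0])}
metrics = get_metrics(config["ops"], tag2tensor, eps=1e-8)
# -> {"net.layer1/input": {"min": ..., "max": ..., "norm": ...}}, each a float32 Tensor
```

Note that `validate_config` mutates its argument: with an empty `targets` it switches to whole-model collection (`is_select = False`), and with `xy_distribution` it additionally sets `all_xy`.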
msprobe/mindspore/ms_config.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -106,12 +106,18 @@ class GradProbeConfig(BaseConfig):
         check_numeral_list_ascend(self.bounds)
 
 
+class StructureConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+
+
 TaskDict = {
     Const.TENSOR: TensorConfig,
     Const.STATISTICS: StatisticsConfig,
     Const.OVERFLOW_CHECK: OverflowCheckConfig,
     Const.FREE_BENCHMARK: FreeBenchmarkConfig,
-    Const.GRAD_PROBE: GradProbeConfig
+    Const.GRAD_PROBE: GradProbeConfig,
+    Const.STRUCTURE: StructureConfig
 }
 
 
```
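The new `StructureConfig` entry wires a structure task into the existing `TaskDict` dispatch. A plausible `config.json` payload for it, sketched as a Python dict; the `"structure"` task string is inferred from `Const.STRUCTURE`, and the field set mirrors the other tasks' top-level shape (see msprobe/docs/02.config_introduction.md for the authoritative fields):

```python
# Hypothetical config for the new task; values are illustrative, not mandated defaults.
config_json = {
    "task": "structure",        # assumed string value of Const.STRUCTURE
    "dump_path": "./dump_output",
    "rank": [],
    "step": [],
    "level": "L0",
}
```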
msprobe/mindspore/service.py CHANGED

```diff
@@ -22,6 +22,7 @@ import mindspore as ms
 from mindspore import nn
 from mindspore.common.api import _no_grad
 from mindspore.ops.primitive import Primitive
+
 try:
     from mindspore.common._pijit_context import PIJitCaptureContext
 except ImportError:
@@ -31,7 +32,7 @@ else:
 
 from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
 from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import Const, print_tools_ends_info
+from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
                                                         ModuleBackwardInputs)
@@ -68,8 +69,10 @@ class Service:
         self.start_call = False
         self.should_stop_service = False
         self.params_grad_info = {}
+        self.hook_handle_dict = {}
         # Register early to hook as many APIs as possible
         self.register_api_hook()
+        self.init_for_debug_level()
 
     @staticmethod
     def check_model_valid(models):
@@ -138,7 +141,12 @@
         if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
             for param_name, param in params_dict.items():
                 if param.requires_grad:
-
+                    name = ori_name + Const.SEP + param_name
+                    old_handle = self.hook_handle_dict.get(name)
+                    if old_handle and hasattr(old_handle, "remove"):
+                        old_handle.remove()
+                    handle = param.register_hook(grad_hook(cell, ori_name, param_name))
+                    self.hook_handle_dict[name] = handle
 
     def init_params_grad_info(cell, params_dict):
         '''
@@ -168,11 +176,15 @@
         module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
         if target_type == BaseScope.Module_Type_Module:
             api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-            params_dict = {
-
-
+            params_dict = {}
+            if self.config.task != Const.STRUCTURE:
+                params_dict = {
+                    key.split(Const.SEP)[-1]: value
+                    for key, value in cell.parameters_dict(recurse=False).items()
+                }
+            setattr(module_input_output, Const.PARAMS, params_dict)
             # Decide whether parameter hooks need to be registered
-            if
+            if params_dict:
                 ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
                 grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                 # On the first forward-hook call, add the params_grad_name attribute and register parameter hooks
@@ -257,15 +269,20 @@
         self.primitive_counters[primitive_name] += 1
 
     def step(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
         self.reset_status()
 
     def start(self, model=None):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         self.start_call = True
         if self.should_stop_service:
             return
@@ -294,7 +311,10 @@
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
             JitDump.set_config(self.config)
             JitDump.set_data_collector(self.data_collector)
-            ms.common.api
+            if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
+                ms.common.api._MindsporeFunctionExecutor = JitDump
+            else:
+                ms.common.api._JitExecutor = JitDump
             ms.common.api._PyNativeExecutor.grad = JitDump.grad
         if pijit_label:
             PIJitCaptureContext.__enter__ = self.empty
@@ -310,6 +330,8 @@
         JitDump.jit_dump_switch = True
 
     def stop(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
@@ -326,7 +348,8 @@
         self.start_call = False
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
         JitDump.jit_dump_switch = False
 
@@ -370,12 +393,13 @@
         else:
             dump_data_dir = None
 
-
-
-
-
-
-        )
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+
         self.data_collector.initialize_json_file(
             framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
         )
@@ -394,13 +418,13 @@
 
         def get_cell_or_module(model):
             return model.named_modules() if is_mindtorch() else model.cells_and_names()
-
+
         if isinstance(self.model, (list, tuple)):
             for index, model in enumerate(self.model):
                 cells_and_names_with_index[str(index)] = get_cell_or_module(model)
         else:
             cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
-        return cells_and_names_with_index
+        return cells_and_names_with_index
 
     def register_primitive_hook(self):
         if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
@@ -430,7 +454,7 @@
         if not self.model:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                    f"The current level is {self.config.level}, the model cannot be None")
-        model_type = Const.MODULE if is_mindtorch() else Const.CELL
+        model_type = Const.MODULE if is_mindtorch() else Const.CELL
         cells_and_names_with_index = self.get_cells_and_names()
 
         for index, cells_and_names in cells_and_names_with_index.items():
@@ -439,7 +463,7 @@
             if cell == model:
                 continue
             cell_index = (index + Const.SEP) if index != "-1" else ""
-            prefix = (model_type + Const.SEP + cell_index + name +
+            prefix = (model_type + Const.SEP + cell_index + name +
                       Const.SEP + cell.__class__.__name__ + Const.SEP)
             _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
             cell.register_forward_hook(forward_hook)
@@ -456,10 +480,9 @@
 
     def reset_status(self):
         self.primitive_hook_service.primitive_counters.clear()
-        self.data_collector.
+        self.data_collector.reset_status()
         JitDump.jit_count = defaultdict(int)
         self.params_grad_info.clear()
-
         if self.config.level == Const.LEVEL_L2:
             self.data_collector.data_processor.reset_status()
             return
@@ -467,3 +490,54 @@
             return
         if self.config.rank and self.current_rank not in self.config.rank:
             return
+
+    def init_for_debug_level(self):
+        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+            return
+        try:
+            self.current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            self.current_rank = None
+        # dir: dump_path -- rank{} -- debug.json
+        self.dump_iter_dir = self.config.dump_path
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        create_directory(dump_dir)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            create_directory(dump_data_dir)
+        else:
+            dump_data_dir = None
+
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(
+            framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
+        )
+        self.debug_variable_counter = defaultdict(int)
+
+    def save(self, variable, name, save_backward):
+        '''
+        Args:
+            variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
+            name: str
+            save_backward: boolean
+        Return:
+            void
+        '''
+        if self.config.level != Const.LEVEL_DEBUG:
+            return
+        count = self.debug_variable_counter[name]
+        self.debug_variable_counter[name] += 1
+
+        name_with_count = f"{name}.{count}"
+        grad_name_with_count = f"{name}_grad.{count}"
+
+        # forward save
+        self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+        # backward save
+        if save_backward:
+            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
```
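Taken together, `init_for_debug_level` and `save` back the new debug level documented in msprobe/docs/28.debugger_save_instruction.md. A usage sketch, with the tensor and config contents purely illustrative and `PrecisionDebugger` assumed as the public entry point re-exported by `msprobe.mindspore`:

```python
import mindspore as ms
from msprobe.mindspore import PrecisionDebugger

debugger = PrecisionDebugger(config_path="./config.json")  # config.json sets "level": "debug"
x = ms.Tensor([1.0, 2.0, 3.0])                             # illustrative variable to inspect
debugger.save(x, "x", save_backward=True)
# Repeated saves under the same name auto-increment the record: "x.0", "x.1", ...
# (plus "x_grad.N" when save_backward is set); records land in
# {dump_path}/rank{N}/debug.json, the layout created by init_for_debug_level.
```

Because `start`, `stop`, and `step` all return early at `LEVEL_DEBUG`, the save path is fully decoupled from the hook-based dump flow.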
msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py CHANGED
(the hunk below belongs to class OperatorScriptGenerator, which the file list above places in op_generator.py)

```diff
@@ -399,7 +399,7 @@ class OperatorScriptGenerator:
     def generate_kwargs_dict(self, kwargs_info, flag_device):
         kwargs_dict_generator = ""
         for key, value in kwargs_info.items():
-            kwargs_dict_generator += '"' + key + '"' + MonitorConst.
+            kwargs_dict_generator += '"' + key + '"' + MonitorConst.NAME_SEP
             if flag_device:
                 kwargs_dict_generator += self.recursive_kwargs_dict(value, flag_device=True) + Const.COMMA
             else:
```