mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/monitor/utils.py
CHANGED
|
@@ -12,20 +12,9 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
import inspect
|
|
16
|
-
from collections import namedtuple
|
|
17
|
-
from datetime import timezone, timedelta
|
|
18
|
-
from functools import wraps
|
|
19
|
-
from datetime import datetime
|
|
20
|
-
import os
|
|
21
|
-
import re
|
|
22
|
-
|
|
23
15
|
import torch
|
|
24
16
|
|
|
25
|
-
from msprobe.core.common.const import MonitorConst
|
|
26
17
|
from msprobe.pytorch.common.log import logger
|
|
27
|
-
from msprobe.core.common.utils import is_int
|
|
28
|
-
from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
|
|
29
18
|
|
|
30
19
|
|
|
31
20
|
device = "cpu"
|
|
@@ -37,23 +26,6 @@ except ImportError:
|
|
|
37
26
|
device = "cuda"
|
|
38
27
|
|
|
39
28
|
NAN_TENSOR_ON_DEVICE = None
|
|
40
|
-
FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
|
|
41
|
-
FILE_NAME_MAX_LENGTH = 255
|
|
42
|
-
DIRECTORY_MAX_LENGTH = 4096
|
|
43
|
-
|
|
44
|
-
beijing_tz = timezone(timedelta(hours=8))
|
|
45
|
-
MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class MsgConst:
|
|
49
|
-
"""
|
|
50
|
-
Class for log messages const
|
|
51
|
-
"""
|
|
52
|
-
SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def get_output_base_dir():
|
|
56
|
-
return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
|
|
57
29
|
|
|
58
30
|
|
|
59
31
|
def get_nan_tensor():
|
|
@@ -63,16 +35,6 @@ def get_nan_tensor():
|
|
|
63
35
|
return NAN_TENSOR_ON_DEVICE
|
|
64
36
|
|
|
65
37
|
|
|
66
|
-
def filter_special_chars(func):
|
|
67
|
-
@wraps(func)
|
|
68
|
-
def func_level(msg):
|
|
69
|
-
for char in MsgConst.SPECIAL_CHAR:
|
|
70
|
-
msg = msg.replace(char, '_')
|
|
71
|
-
return func(msg)
|
|
72
|
-
|
|
73
|
-
return func_level
|
|
74
|
-
|
|
75
|
-
|
|
76
38
|
def get_param_struct(param):
|
|
77
39
|
res = {}
|
|
78
40
|
if isinstance(param, (tuple, list)):
|
|
@@ -85,282 +47,4 @@ def get_param_struct(param):
|
|
|
85
47
|
else:
|
|
86
48
|
res['config'] = f'{type(param)}'
|
|
87
49
|
logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}')
|
|
88
|
-
return res
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def validate_ops(ops):
|
|
92
|
-
if not isinstance(ops, list):
|
|
93
|
-
raise TypeError("ops should be a list")
|
|
94
|
-
valid_ops = []
|
|
95
|
-
for op in ops:
|
|
96
|
-
if op not in MonitorConst.OP_LIST:
|
|
97
|
-
logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
|
|
98
|
-
continue
|
|
99
|
-
valid_ops.append(op)
|
|
100
|
-
if not valid_ops:
|
|
101
|
-
default_op = MonitorConst.OP_LIST[0]
|
|
102
|
-
valid_ops.append(default_op)
|
|
103
|
-
logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
|
|
104
|
-
# 增加默认shape和dtype参数
|
|
105
|
-
if "shape" not in valid_ops:
|
|
106
|
-
valid_ops.append("shape")
|
|
107
|
-
if "dtype" not in valid_ops:
|
|
108
|
-
valid_ops.append("dtype")
|
|
109
|
-
return valid_ops
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def validate_ndigits(ndigits):
|
|
113
|
-
if not ndigits:
|
|
114
|
-
return
|
|
115
|
-
if not is_int(ndigits) or ndigits <= 0:
|
|
116
|
-
raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
|
|
117
|
-
if ndigits > MonitorConst.MAX_NDIGITS:
|
|
118
|
-
raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def validate_ranks(ranks):
|
|
122
|
-
if not isinstance(ranks, list):
|
|
123
|
-
raise TypeError("module_ranks should be a list")
|
|
124
|
-
for rank in ranks:
|
|
125
|
-
if not isinstance(rank, int) or isinstance(rank, bool):
|
|
126
|
-
raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def validate_targets(targets):
|
|
130
|
-
if not isinstance(targets, dict):
|
|
131
|
-
raise TypeError('targets in config.json should be a dict')
|
|
132
|
-
for module_name, field in targets.items():
|
|
133
|
-
if not isinstance(module_name, str):
|
|
134
|
-
raise TypeError('key of targets should be module_name[str] in config.json')
|
|
135
|
-
if not isinstance(field, dict):
|
|
136
|
-
raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def validate_print_struct(print_struct):
|
|
140
|
-
if not isinstance(print_struct, bool):
|
|
141
|
-
raise TypeError("print_struct should be a bool")
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def validate_ur_distribution(ur_distribution):
|
|
145
|
-
if not isinstance(ur_distribution, bool):
|
|
146
|
-
raise TypeError('ur_distribution should be a bool')
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def validate_xy_distribution(xy_distribution):
|
|
150
|
-
if not isinstance(xy_distribution, bool):
|
|
151
|
-
raise TypeError('xy_distribution should be a bool')
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
def validate_wg_distribution(wg_distribution):
|
|
155
|
-
if not isinstance(wg_distribution, bool):
|
|
156
|
-
raise TypeError('wg_distribution should be a bool')
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def validate_mg_distribution(mg_distribution):
|
|
160
|
-
if not isinstance(mg_distribution, bool):
|
|
161
|
-
raise TypeError('mg_distribution should be a bool')
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
def validate_param_distribution(param_distribution):
|
|
165
|
-
if not isinstance(param_distribution, bool):
|
|
166
|
-
raise TypeError('param_distribution should be a bool')
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def validate_cc_distribution(cc_distribution):
|
|
170
|
-
if not isinstance(cc_distribution, dict):
|
|
171
|
-
raise TypeError('cc_distribution should be a dictionary')
|
|
172
|
-
for key, value in cc_distribution.items():
|
|
173
|
-
if key == 'enable':
|
|
174
|
-
if not isinstance(value, bool):
|
|
175
|
-
raise TypeError('cc_distribution enable should be a bool')
|
|
176
|
-
elif key == 'cc_codeline':
|
|
177
|
-
if not isinstance(value, list):
|
|
178
|
-
raise TypeError('cc_distribution cc_codeline should be a list')
|
|
179
|
-
elif key == 'cc_pre_hook':
|
|
180
|
-
if not isinstance(value, bool):
|
|
181
|
-
raise TypeError('cc_distribution cc_pre_hook should be a bool')
|
|
182
|
-
elif key == 'cc_log_only':
|
|
183
|
-
if not isinstance(value, bool):
|
|
184
|
-
raise TypeError('cc_distribution cc_log_only should be a bool')
|
|
185
|
-
else:
|
|
186
|
-
raise TypeError(f'{key} of cc_distribution is not supported.')
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
def validate_squash_name(squash_name):
|
|
190
|
-
if not isinstance(squash_name, bool):
|
|
191
|
-
raise TypeError('squash_name should be a bool')
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
def validate_alert(alert):
|
|
195
|
-
if not isinstance(alert, dict):
|
|
196
|
-
raise TypeError('alert should be a dictionary')
|
|
197
|
-
rules = alert.get('rules')
|
|
198
|
-
if rules and isinstance(rules, list):
|
|
199
|
-
for rule in rules:
|
|
200
|
-
rule_name = rule.get("rule_name")
|
|
201
|
-
if rule_name and rule_name not in MonitorConst.RULE_NAME:
|
|
202
|
-
raise TypeError(f"{rule_name} is not supported")
|
|
203
|
-
args = rule.get("args")
|
|
204
|
-
if args and isinstance(args, dict):
|
|
205
|
-
threshold = args.get("threshold")
|
|
206
|
-
if not isinstance(threshold, (float, int)) or threshold < 0:
|
|
207
|
-
raise TypeError('threshold must be float and not less than 0')
|
|
208
|
-
dump = alert.get('dump')
|
|
209
|
-
if dump and not isinstance(dump, bool):
|
|
210
|
-
raise TypeError('dump must be bool.')
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
def validate_step_count_per_record(step_count_per_record):
|
|
214
|
-
if not is_int(step_count_per_record):
|
|
215
|
-
raise TypeError('step_count_per_record must be int.')
|
|
216
|
-
if step_count_per_record < 1:
|
|
217
|
-
raise ValueError("step_count_per_record must greater than 0")
|
|
218
|
-
if step_count_per_record > 1e6:
|
|
219
|
-
raise ValueError("step_count_per_record must smaller than 1e6")
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
def validate_dynamic_on(dynamic_on):
|
|
223
|
-
if not isinstance(dynamic_on, bool):
|
|
224
|
-
raise TypeError('dynamic_on should be a bool')
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
def validate_monitor_mbs_grad(monitor_mbs_grad):
|
|
228
|
-
if not isinstance(monitor_mbs_grad, bool):
|
|
229
|
-
logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
|
|
230
|
-
return False
|
|
231
|
-
return monitor_mbs_grad
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def validate_config(config):
|
|
235
|
-
config['ops'] = validate_ops(config.get('ops', []))
|
|
236
|
-
|
|
237
|
-
ndigits = config.get('ndigits')
|
|
238
|
-
validate_ndigits(ndigits)
|
|
239
|
-
|
|
240
|
-
eps = config.get('eps', 1e-8)
|
|
241
|
-
if not isinstance(eps, float):
|
|
242
|
-
raise TypeError("eps should be a float")
|
|
243
|
-
|
|
244
|
-
ranks = config.get("module_ranks", [])
|
|
245
|
-
validate_ranks(ranks)
|
|
246
|
-
|
|
247
|
-
targets = config.get("targets", {})
|
|
248
|
-
validate_targets(targets)
|
|
249
|
-
|
|
250
|
-
print_struct = config.get('print_struct', False)
|
|
251
|
-
validate_print_struct(print_struct)
|
|
252
|
-
|
|
253
|
-
ur_distribution = config.get('ur_distribution', False)
|
|
254
|
-
validate_ur_distribution(ur_distribution)
|
|
255
|
-
|
|
256
|
-
xy_distribution = config.get('xy_distribution', False)
|
|
257
|
-
validate_xy_distribution(xy_distribution)
|
|
258
|
-
|
|
259
|
-
wg_distribution = config.get('wg_distribution', False)
|
|
260
|
-
validate_wg_distribution(wg_distribution)
|
|
261
|
-
|
|
262
|
-
mg_distribution = config.get('mg_distribution', False)
|
|
263
|
-
validate_mg_distribution(mg_distribution)
|
|
264
|
-
|
|
265
|
-
param_distribution = config.get('param_distribution', False)
|
|
266
|
-
validate_param_distribution(param_distribution)
|
|
267
|
-
|
|
268
|
-
cc_distribution = config.get('cc_distribution', {})
|
|
269
|
-
validate_cc_distribution(cc_distribution)
|
|
270
|
-
|
|
271
|
-
alert = config.get('alert', {})
|
|
272
|
-
validate_alert(alert)
|
|
273
|
-
|
|
274
|
-
step_count_per_record = config.get('step_count_per_record', 1)
|
|
275
|
-
validate_step_count_per_record(step_count_per_record)
|
|
276
|
-
|
|
277
|
-
config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
|
|
278
|
-
MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
|
|
279
|
-
config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
|
|
280
|
-
MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
|
|
281
|
-
MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
|
|
282
|
-
config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
|
|
283
|
-
MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
|
|
284
|
-
|
|
285
|
-
squash_name = config.get('squash_name', True)
|
|
286
|
-
validate_squash_name(squash_name)
|
|
287
|
-
|
|
288
|
-
config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
|
|
289
|
-
|
|
290
|
-
dynamic_on = config.get('dynamic_on', False)
|
|
291
|
-
validate_dynamic_on(dynamic_on)
|
|
292
|
-
|
|
293
|
-
if not targets:
|
|
294
|
-
if xy_distribution:
|
|
295
|
-
config["all_xy"] = True
|
|
296
|
-
config["targets"] = {"": {}}
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
def time_str2time_digit(time_str):
|
|
300
|
-
time_format = '%b%d_%H-%M-%S'
|
|
301
|
-
if not isinstance(time_str, str):
|
|
302
|
-
raise TypeError(f"time_str:{time_str} should be a str")
|
|
303
|
-
try:
|
|
304
|
-
time_digit = datetime.strptime(time_str, time_format)
|
|
305
|
-
except Exception as e:
|
|
306
|
-
raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
|
|
307
|
-
of existing output dirpath, like 'Dec03_21-34-40'.") from e
|
|
308
|
-
return time_digit
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
def get_target_output_dir(monitor_path, time_start, time_end):
|
|
312
|
-
check_file_or_directory_path(monitor_path, isdir=True)
|
|
313
|
-
time_start = time_str2time_digit(time_start) if time_start is not None else time_start
|
|
314
|
-
time_end = time_str2time_digit(time_end) if time_end is not None else time_end
|
|
315
|
-
if time_start and time_end and time_start > time_end:
|
|
316
|
-
raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
|
|
317
|
-
result = {}
|
|
318
|
-
for dirname in os.listdir(monitor_path):
|
|
319
|
-
match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
|
|
320
|
-
if not match:
|
|
321
|
-
continue
|
|
322
|
-
time_tag = match.group(1)
|
|
323
|
-
rank = match.group(2)
|
|
324
|
-
target_time = time_str2time_digit(time_tag)
|
|
325
|
-
start_ok = time_start is None or target_time >= time_start
|
|
326
|
-
end_ok = time_end is None or target_time <= time_end
|
|
327
|
-
if start_ok and end_ok:
|
|
328
|
-
result[rank] = os.path.join(monitor_path, dirname)
|
|
329
|
-
return result
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
def chmod_tensorboard_dir(path):
|
|
333
|
-
"""
|
|
334
|
-
format配置为tensorboard时,需要补充文件权限设置
|
|
335
|
-
"""
|
|
336
|
-
try:
|
|
337
|
-
recursive_chmod(path)
|
|
338
|
-
except Exception as e:
|
|
339
|
-
logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
def validate_set_monitor(grad_acc_steps, start_iteration):
|
|
343
|
-
"""
|
|
344
|
-
validate parameters of set_monitor.
|
|
345
|
-
"""
|
|
346
|
-
grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
|
|
347
|
-
MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
|
|
348
|
-
|
|
349
|
-
start_iteration = validate_int_arg(start_iteration, "start_iteration",
|
|
350
|
-
MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
|
|
351
|
-
return grad_acc_steps, start_iteration
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
def validate_int_arg(value, name, minimum, default_value):
|
|
355
|
-
"""Validate int args, if any exception occurs, use the default value."""
|
|
356
|
-
if value is None:
|
|
357
|
-
return default_value
|
|
358
|
-
try:
|
|
359
|
-
if not is_int(value):
|
|
360
|
-
raise TypeError(f"{name} must be int")
|
|
361
|
-
if value < minimum:
|
|
362
|
-
raise ValueError(f"{name} must greater than {minimum}")
|
|
363
|
-
except Exception as e:
|
|
364
|
-
value = default_value
|
|
365
|
-
logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
|
|
366
|
-
return value
|
|
50
|
+
return res
|
|
@@ -104,7 +104,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
104
104
|
|
|
105
105
|
if not is_npu:
|
|
106
106
|
return
|
|
107
|
-
logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
|
|
107
|
+
logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}]')
|
|
108
108
|
|
|
109
109
|
if self.process_num > 0:
|
|
110
110
|
self.pool.close()
|
|
@@ -21,7 +21,7 @@ from datetime import datetime, timezone
|
|
|
21
21
|
import torch
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
23
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
24
|
-
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
|
|
24
|
+
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json, check_link, remove_path
|
|
25
25
|
from msprobe.pytorch.common.log import logger
|
|
26
26
|
|
|
27
27
|
|
|
@@ -113,6 +113,12 @@ def save_temp_summary(api_index, single_api_summary, path, lock):
|
|
|
113
113
|
try:
|
|
114
114
|
data = [api_index, single_api_summary]
|
|
115
115
|
save_json(summary_path, data, mode='a')
|
|
116
|
+
except Exception as e:
|
|
117
|
+
logger.error(f'save temp summary error:{e}')
|
|
118
|
+
try:
|
|
119
|
+
remove_path(summary_path)
|
|
120
|
+
except FileNotFoundError:
|
|
121
|
+
logger.error(f'file not found:{summary_path}')
|
|
116
122
|
finally:
|
|
117
123
|
lock.release()
|
|
118
124
|
|
|
@@ -37,8 +37,6 @@ try:
|
|
|
37
37
|
from rich.table import Table
|
|
38
38
|
from rich import print as rich_print
|
|
39
39
|
from rich.columns import Columns
|
|
40
|
-
|
|
41
|
-
install()
|
|
42
40
|
except ImportError as err:
|
|
43
41
|
install = None
|
|
44
42
|
Panel = None
|
|
@@ -228,7 +226,7 @@ class Util:
|
|
|
228
226
|
def check_path_valid(self, path):
|
|
229
227
|
path = self.path_strip(path)
|
|
230
228
|
if not path or not os.path.exists(path):
|
|
231
|
-
self.log.error("The path
|
|
229
|
+
self.log.error("The path does not exist.")
|
|
232
230
|
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
233
231
|
isdir = check_file_type(path) == FileCheckConst.DIR
|
|
234
232
|
check_file_or_directory_path(path, isdir=isdir)
|
|
@@ -236,7 +234,7 @@ class Util:
|
|
|
236
234
|
|
|
237
235
|
def check_files_in_path(self, path):
|
|
238
236
|
if os.path.isdir(path) and len(os.listdir(path)) == 0:
|
|
239
|
-
self.log.error("No files in
|
|
237
|
+
self.log.error("No files found in path.")
|
|
240
238
|
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
241
239
|
|
|
242
240
|
def npy_info(self, source_data):
|
|
@@ -66,7 +66,7 @@ def _compare_graph_result(input_param, args):
|
|
|
66
66
|
# 对两个数据进行构图
|
|
67
67
|
graph_n = _build_graph_info(input_param.get('npu_path'), args)
|
|
68
68
|
graph_b = _build_graph_info(input_param.get('bench_path'), args)
|
|
69
|
-
logger.info('Model graphs built successfully, start
|
|
69
|
+
logger.info('Model graphs built successfully, start comparing graphs...')
|
|
70
70
|
# 基于graph、stack和data进行比较
|
|
71
71
|
graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
|
|
72
72
|
# 增加micro step标记
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|