mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +37 -47
- msprobe/README.md +8 -5
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +64 -13
- msprobe/core/common/framework_adapter.py +10 -1
- msprobe/core/common/utils.py +17 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +2 -16
- msprobe/core/service.py +5 -16
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +0 -13
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +0 -70
- msprobe/pytorch/debugger/debugger_config.py +0 -10
- msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
- msprobe/pytorch/hook_module/api_register.py +5 -1
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +2 -11
- msprobe/visualization/builder/graph_builder.py +2 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
|
@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
|
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
52
|
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
56
54
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
|
|
57
55
|
ExecParams
|
|
58
56
|
|
|
@@ -90,27 +88,22 @@ seed_all()
|
|
|
90
88
|
|
|
91
89
|
def run_ut(config):
|
|
92
90
|
logger.info("start UT test")
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
else:
|
|
97
|
-
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
98
|
-
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
91
|
+
|
|
92
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
93
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
99
94
|
|
|
100
95
|
if config.save_error_data:
|
|
101
96
|
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
102
97
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
103
98
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
api_name_set = set()
|
|
113
|
-
run_api_offline(config, compare, api_name_set)
|
|
99
|
+
|
|
100
|
+
csv_df = read_csv(config.result_csv_path)
|
|
101
|
+
try:
|
|
102
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
103
|
+
except IndexError:
|
|
104
|
+
logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
|
|
105
|
+
api_name_set = set()
|
|
106
|
+
run_api_offline(config, compare, api_name_set)
|
|
114
107
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
115
108
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
116
109
|
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
164
157
|
gc.collect()
|
|
165
158
|
|
|
166
159
|
|
|
167
|
-
def run_api_online(config, compare):
|
|
168
|
-
attl = init_attl(config.online_config)
|
|
169
|
-
dispatcher = ConsumerDispatcher(compare=compare)
|
|
170
|
-
dispatcher.start(handle_func=run_torch_api_online, config=config)
|
|
171
|
-
|
|
172
|
-
def tcp_communication_flow():
|
|
173
|
-
while True:
|
|
174
|
-
api_data = attl.recv()
|
|
175
|
-
if api_data == 'STOP_':
|
|
176
|
-
continue
|
|
177
|
-
if api_data == 'KILL_':
|
|
178
|
-
time.sleep(1)
|
|
179
|
-
logger.info("==========接收到STOP信号==========")
|
|
180
|
-
dispatcher.stop()
|
|
181
|
-
attl.stop_serve()
|
|
182
|
-
time.sleep(1)
|
|
183
|
-
break
|
|
184
|
-
if not isinstance(api_data, ApiData):
|
|
185
|
-
continue
|
|
186
|
-
api_full_name = api_data.name
|
|
187
|
-
_, api_name = extract_basic_api_segments(api_full_name)
|
|
188
|
-
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
189
|
-
continue
|
|
190
|
-
if api_data.rank in config.online_config.rank_list:
|
|
191
|
-
dispatcher.update_consume_queue(api_data)
|
|
192
|
-
|
|
193
|
-
def shared_storage_communication_flow():
|
|
194
|
-
flag_num = -1
|
|
195
|
-
while True:
|
|
196
|
-
api_data = attl.download()
|
|
197
|
-
if api_data == "start":
|
|
198
|
-
if flag_num == -1:
|
|
199
|
-
flag_num += 1
|
|
200
|
-
flag_num += 1
|
|
201
|
-
if api_data == "end":
|
|
202
|
-
flag_num -= 1
|
|
203
|
-
if flag_num == 0:
|
|
204
|
-
dispatcher.stop()
|
|
205
|
-
break
|
|
206
|
-
if not isinstance(api_data, ApiData):
|
|
207
|
-
continue
|
|
208
|
-
api_full_name = api_data.name
|
|
209
|
-
_, api_name = extract_basic_api_segments(api_full_name)
|
|
210
|
-
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
211
|
-
continue
|
|
212
|
-
if api_data.rank in config.online_config.rank_list:
|
|
213
|
-
dispatcher.update_consume_queue(api_data)
|
|
214
|
-
|
|
215
|
-
if config.online_config.nfs_path:
|
|
216
|
-
shared_storage_communication_flow()
|
|
217
|
-
else:
|
|
218
|
-
tcp_communication_flow()
|
|
219
|
-
|
|
220
|
-
|
|
221
160
|
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
222
161
|
"""
|
|
223
162
|
run api(api_name) if api_name not in black_list and in white_list.
|
|
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
315
254
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
316
255
|
|
|
317
256
|
|
|
318
|
-
def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
319
|
-
in_fwd_data_list = []
|
|
320
|
-
api_type, api_name = extract_basic_api_segments(api_full_name)
|
|
321
|
-
args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
|
|
322
|
-
in_fwd_data_list.append(args)
|
|
323
|
-
in_fwd_data_list.append(kwargs)
|
|
324
|
-
if kwargs.get("device"):
|
|
325
|
-
del kwargs["device"]
|
|
326
|
-
|
|
327
|
-
device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
|
|
328
|
-
device_out = exec_api(device_exec_params)
|
|
329
|
-
device_out = move2device_exec(device_out, "cpu")
|
|
330
|
-
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
331
|
-
|
|
332
|
-
|
|
333
257
|
def check_need_grad(api_info_dict):
|
|
334
258
|
need_grad = True
|
|
335
259
|
if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
|
|
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
|
|
|
389
313
|
return error_data_path
|
|
390
314
|
|
|
391
315
|
|
|
392
|
-
def init_attl(config):
|
|
393
|
-
"""config: OnlineConfig"""
|
|
394
|
-
attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
|
|
395
|
-
connect_ip=config.host,
|
|
396
|
-
connect_port=config.port,
|
|
397
|
-
nfs_path=config.nfs_path,
|
|
398
|
-
tls_path=config.tls_path))
|
|
399
|
-
return attl
|
|
400
|
-
|
|
401
|
-
|
|
402
316
|
def _run_ut_parser(parser):
|
|
403
317
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
|
|
404
318
|
help="<Optional> The api param tool result file: generate from api param tool, "
|
|
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
|
|
|
481
395
|
_run_ut_parser(parser)
|
|
482
396
|
args = parser.parse_args(sys.argv[1:])
|
|
483
397
|
run_ut_command(args)
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
def checked_online_config(online_config):
|
|
487
|
-
if not online_config.is_online:
|
|
488
|
-
return
|
|
489
|
-
if not isinstance(online_config.is_online, bool):
|
|
490
|
-
raise ValueError("is_online must be bool type")
|
|
491
|
-
# rank_list
|
|
492
|
-
if not isinstance(online_config.rank_list, list):
|
|
493
|
-
raise ValueError("rank_list must be a list")
|
|
494
|
-
if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
|
|
495
|
-
raise ValueError("All elements in rank_list must be integers")
|
|
496
|
-
|
|
497
|
-
# nfs_path
|
|
498
|
-
if online_config.nfs_path:
|
|
499
|
-
check_file_or_directory_path(online_config.nfs_path, isdir=True)
|
|
500
|
-
return
|
|
501
|
-
# tls_path
|
|
502
|
-
if online_config.tls_path:
|
|
503
|
-
check_file_or_directory_path(online_config.tls_path, isdir=True)
|
|
504
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
505
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
506
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
|
|
507
|
-
crl_path = os.path.join(online_config.tls_path, "crl.pem")
|
|
508
|
-
if os.path.exists(crl_path):
|
|
509
|
-
check_file_or_directory_path(crl_path)
|
|
510
|
-
|
|
511
|
-
# host and port
|
|
512
|
-
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
513
|
-
raise Exception(f"host: {online_config.host} is invalid.")
|
|
514
|
-
if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
|
|
515
|
-
raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
|
|
516
398
|
|
|
517
399
|
|
|
518
400
|
def run_ut_command(args):
|
|
@@ -525,7 +407,7 @@ def run_ut_command(args):
|
|
|
525
407
|
else:
|
|
526
408
|
checker_config = CheckerConfig()
|
|
527
409
|
|
|
528
|
-
if not
|
|
410
|
+
if not args.api_info_file:
|
|
529
411
|
logger.error("Please provide api_info_file for offline run ut.")
|
|
530
412
|
raise Exception("Please provide api_info_file for offline run ut.")
|
|
531
413
|
|
|
@@ -588,8 +470,6 @@ def run_ut_command(args):
|
|
|
588
470
|
global UT_ERROR_DATA_DIR
|
|
589
471
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
590
472
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
591
|
-
online_config = checker_config.get_online_config()
|
|
592
|
-
checked_online_config(online_config)
|
|
593
473
|
config_params = {
|
|
594
474
|
'forward_content': forward_content,
|
|
595
475
|
'backward_content': backward_content,
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -337,56 +337,6 @@ def save_pt(tensor, filepath):
|
|
|
337
337
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
338
338
|
|
|
339
339
|
|
|
340
|
-
class TypeCheckingUnpickler(pickle.Unpickler):
|
|
341
|
-
"""
|
|
342
|
-
This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
|
|
343
|
-
It overrides the find_class method to add type checking functionality.
|
|
344
|
-
"""
|
|
345
|
-
allowed_types = [
|
|
346
|
-
"str",
|
|
347
|
-
"ApiData",
|
|
348
|
-
"OrderedDict",
|
|
349
|
-
"_rebuild_tensor_v2", # from torch.utils
|
|
350
|
-
"_load_from_bytes" # from torch.storage
|
|
351
|
-
]
|
|
352
|
-
|
|
353
|
-
def find_class(self, module, name):
|
|
354
|
-
"""
|
|
355
|
-
Method to find the class of the object to be unpickled.
|
|
356
|
-
Throws pickle.UnpicklingError If the object type is not in the allowed types list.
|
|
357
|
-
"""
|
|
358
|
-
if name in self.allowed_types:
|
|
359
|
-
return super().find_class(module, name)
|
|
360
|
-
raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
def save_pkl(tensor, filepath):
|
|
364
|
-
"""Save ApiData or str objection by pickle"""
|
|
365
|
-
check_path_before_create(filepath)
|
|
366
|
-
filepath = os.path.realpath(filepath)
|
|
367
|
-
try:
|
|
368
|
-
with FileOpen(filepath, 'wb') as f:
|
|
369
|
-
pickle.dump(tensor, f)
|
|
370
|
-
except Exception as e:
|
|
371
|
-
logger.error("Save pt file failed, please check according possible error causes: "
|
|
372
|
-
"1. out of disk space or disk error, "
|
|
373
|
-
"2. no permission to write files, etc.")
|
|
374
|
-
raise RuntimeError(f"save pt file {filepath} failed") from e
|
|
375
|
-
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def load_pkl(pt_path):
|
|
379
|
-
"""Load ApiData or str objection by pickle for accuracy_checker_online"""
|
|
380
|
-
check_file_or_directory_path(pt_path)
|
|
381
|
-
pt_path = os.path.realpath(pt_path)
|
|
382
|
-
try:
|
|
383
|
-
with FileOpen(pt_path, 'rb') as f:
|
|
384
|
-
pt = TypeCheckingUnpickler(f).load()
|
|
385
|
-
except Exception as e:
|
|
386
|
-
raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
|
|
387
|
-
return pt
|
|
388
|
-
|
|
389
|
-
|
|
390
340
|
def is_recomputation():
|
|
391
341
|
"""Check if the current operation is in the re-computation phase.
|
|
392
342
|
|
|
@@ -471,23 +421,3 @@ def register_forward_hook(module, forward_hook):
|
|
|
471
421
|
module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
472
422
|
else:
|
|
473
423
|
module.register_forward_hook(forward_hook)
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
def save_api_data(api_data):
|
|
477
|
-
"""Save data to io stream"""
|
|
478
|
-
try:
|
|
479
|
-
io_buff = io.BytesIO()
|
|
480
|
-
torch.save(api_data, io_buff)
|
|
481
|
-
except Exception as e:
|
|
482
|
-
raise RuntimeError(f"save api_data to io_buff failed") from e
|
|
483
|
-
return io_buff
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
def load_api_data(api_data_bytes):
|
|
487
|
-
"""Load data from bytes stream"""
|
|
488
|
-
try:
|
|
489
|
-
buffer = io.BytesIO(api_data_bytes)
|
|
490
|
-
buffer = torch.load(buffer, map_location="cpu")
|
|
491
|
-
except Exception as e:
|
|
492
|
-
raise RuntimeError(f"load api_data from bytes failed") from e
|
|
493
|
-
return buffer
|
|
@@ -48,16 +48,6 @@ class DebuggerConfig:
|
|
|
48
48
|
"max_sample": task_config.max_sample
|
|
49
49
|
}
|
|
50
50
|
|
|
51
|
-
self.online_run_ut = False
|
|
52
|
-
if self.task == Const.TENSOR:
|
|
53
|
-
# dump api tensor and collaborate with online run_ut
|
|
54
|
-
self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
|
|
55
|
-
self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
|
|
56
|
-
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
57
|
-
self.host = task_config.host if task_config.host else ""
|
|
58
|
-
self.port = task_config.port if task_config.port else -1
|
|
59
|
-
self.online_run_ut_recompute = task_config.online_run_ut_recompute \
|
|
60
|
-
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
61
51
|
|
|
62
52
|
self.check()
|
|
63
53
|
self._check_statistics_config(task_config)
|
|
@@ -63,9 +63,11 @@ def wrap_forward_with_hook_safety(module):
|
|
|
63
63
|
except _StopRecomputationError as e:
|
|
64
64
|
exception_output = None
|
|
65
65
|
if len(module._forward_hooks.values()) > 0:
|
|
66
|
-
# msprobe的forward_hook
|
|
67
|
-
hook_fn
|
|
68
|
-
|
|
66
|
+
# 仅执行msprobe的forward_hook, hook名称必然包含'ModuleProcesser.'
|
|
67
|
+
for hook_fn in module._forward_hooks.values():
|
|
68
|
+
if 'ModuleProcesser' in str(hook_fn):
|
|
69
|
+
hook_fn(module, args, kwargs, exception_output)
|
|
70
|
+
break
|
|
69
71
|
raise e
|
|
70
72
|
|
|
71
73
|
if torch_version_above_or_equal_21:
|
|
@@ -152,7 +154,13 @@ class ModuleProcesser:
|
|
|
152
154
|
modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
|
|
153
155
|
for index, modules_and_names in modules_and_names_with_index.items():
|
|
154
156
|
model = models if index == "-1" else models[int(index)]
|
|
157
|
+
|
|
158
|
+
model_list = []
|
|
155
159
|
for name, module in modules_and_names:
|
|
160
|
+
model_list.append((name, module))
|
|
161
|
+
|
|
162
|
+
is_verl = "verl" in sys.modules
|
|
163
|
+
for idx, (name, module) in enumerate(model_list):
|
|
156
164
|
if recursive and module == model:
|
|
157
165
|
continue
|
|
158
166
|
if not is_torch_nn_module(module):
|
|
@@ -163,6 +171,13 @@ class ModuleProcesser:
|
|
|
163
171
|
continue
|
|
164
172
|
if module.__class__.__name__ == "FullyShardedDataParallel":
|
|
165
173
|
continue
|
|
174
|
+
|
|
175
|
+
# verl 场景下跳过第一层和最后一层
|
|
176
|
+
if is_verl and (idx == 1 or idx == len(model_list) - 1):
|
|
177
|
+
logger.warning(f"The module {name} is the first or last layer in verl scenario, "
|
|
178
|
+
f"the data dump for this module will be skipped.")
|
|
179
|
+
continue
|
|
180
|
+
|
|
166
181
|
setattr(module, 'msprobe_hook', True)
|
|
167
182
|
module_index = (index + Const.SEP) if index != "-1" else ""
|
|
168
183
|
prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
|
|
@@ -135,13 +135,17 @@ def redirect_wait():
|
|
|
135
135
|
store_func = dist_data_collect_func.pop(args[0])
|
|
136
136
|
store_func()
|
|
137
137
|
return
|
|
138
|
+
remove_value = None
|
|
138
139
|
for value in dist_batch_data_collect_func:
|
|
139
140
|
if args[0] in value[0]:
|
|
140
141
|
value[0].remove(args[0])
|
|
141
142
|
if len(value[0]) == 0:
|
|
142
143
|
store_func = value[1]
|
|
143
144
|
store_func()
|
|
144
|
-
|
|
145
|
+
remove_value = value
|
|
146
|
+
break
|
|
147
|
+
if remove_value:
|
|
148
|
+
dist_batch_data_collect_func.remove(remove_value)
|
|
145
149
|
|
|
146
150
|
return wrapped_wait
|
|
147
151
|
|
|
@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
|
|
|
48
48
|
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
|
|
49
49
|
from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
|
|
50
50
|
|
|
51
|
-
|
|
52
51
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
53
52
|
if not torch_version_above_or_equal_2:
|
|
54
53
|
raise ValueError("monitor require torch>=2.0")
|
|
55
54
|
|
|
56
|
-
|
|
57
55
|
FORMAT_MAPPING = {
|
|
58
56
|
MonitorConst.TENSORBOARD: SummaryWriterWithAD,
|
|
59
57
|
MonitorConst.CSV: CSVWriterWithAD,
|
|
@@ -150,15 +148,11 @@ class GradContext:
|
|
|
150
148
|
def __init__(self) -> None:
|
|
151
149
|
self.pre = {}
|
|
152
150
|
self.post = {}
|
|
153
|
-
self.acc_metric = {}
|
|
154
|
-
self.acc = {}
|
|
155
151
|
self.actv = {}
|
|
156
152
|
|
|
157
153
|
def reset(self):
|
|
158
154
|
self.pre.clear()
|
|
159
155
|
self.post.clear()
|
|
160
|
-
self.acc_metric.clear()
|
|
161
|
-
self.acc.clear()
|
|
162
156
|
self.actv.clear()
|
|
163
157
|
|
|
164
158
|
|
|
@@ -510,18 +504,8 @@ class TrainerMon:
|
|
|
510
504
|
if not self.wg_distribution:
|
|
511
505
|
return {}, {}
|
|
512
506
|
|
|
513
|
-
if self.weight_hooked:
|
|
514
|
-
get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
|
|
515
|
-
|
|
516
507
|
get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
if self.weight_hooked:
|
|
520
|
-
unreduced_grad = self.grad_context.acc_metric
|
|
521
|
-
else:
|
|
522
|
-
unreduced_grad = self.grad_context.pre
|
|
523
|
-
|
|
524
|
-
return reduced_grad, unreduced_grad
|
|
508
|
+
return self.grad_context.post, self.grad_context.pre
|
|
525
509
|
|
|
526
510
|
def generate_xy_metrics(self):
|
|
527
511
|
actv = {}
|
|
@@ -529,7 +513,6 @@ class TrainerMon:
|
|
|
529
513
|
actv.update(fwd_context.actv)
|
|
530
514
|
|
|
531
515
|
actv_grad = self.grad_context.actv
|
|
532
|
-
|
|
533
516
|
return actv, actv_grad
|
|
534
517
|
|
|
535
518
|
def reload_xy(self, xy_distribution=False):
|
|
@@ -607,11 +590,8 @@ class TrainerMon:
|
|
|
607
590
|
if not self.wg_distribution:
|
|
608
591
|
return
|
|
609
592
|
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
use_micro_step=self.monitor_mbs_grad)
|
|
613
|
-
else:
|
|
614
|
-
self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
|
|
593
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
|
|
594
|
+
use_micro_step=self.monitor_mbs_grad)
|
|
615
595
|
self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
|
|
616
596
|
|
|
617
597
|
def hook_optimizer(self, optimizer):
|
|
@@ -732,9 +712,9 @@ class TrainerMon:
|
|
|
732
712
|
# 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False
|
|
733
713
|
if self.monitoring:
|
|
734
714
|
module_rank_valid = not self.module_rank_list or (
|
|
735
|
-
|
|
715
|
+
dist.is_initialized() and dist.get_rank() in self.module_rank_list)
|
|
736
716
|
step_condition = (context.step >= self.start_step and (
|
|
737
|
-
|
|
717
|
+
context.step - self.start_step) % self.step_interval == 0)
|
|
738
718
|
if module_rank_valid and step_condition:
|
|
739
719
|
self.has_collect_times += 1
|
|
740
720
|
|
|
@@ -791,6 +771,7 @@ class TrainerMon:
|
|
|
791
771
|
hook(optimizer, args, kwargs)
|
|
792
772
|
step_final_hook(optimizer, args, kwargs)
|
|
793
773
|
return out
|
|
774
|
+
|
|
794
775
|
return wrapper
|
|
795
776
|
|
|
796
777
|
optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
|
|
@@ -1013,11 +994,11 @@ class TrainerMon:
|
|
|
1013
994
|
vpp_stage + module_name,
|
|
1014
995
|
]:
|
|
1015
996
|
if pattern in l2_targets:
|
|
1016
|
-
return pattern
|
|
997
|
+
return pattern
|
|
1017
998
|
elif hook_name in ["linear_hook"]:
|
|
1018
999
|
return vpp_stage + squash_param_name(module_name, self.squash_name)
|
|
1019
1000
|
return ""
|
|
1020
|
-
|
|
1001
|
+
|
|
1021
1002
|
def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
|
|
1022
1003
|
if '_modules' not in module.__dict__:
|
|
1023
1004
|
# nothing to hook
|
|
@@ -1151,7 +1132,7 @@ class TrainerMon:
|
|
|
1151
1132
|
context.micro_step = 0
|
|
1152
1133
|
context.step += 1
|
|
1153
1134
|
return
|
|
1154
|
-
|
|
1135
|
+
|
|
1155
1136
|
def stack_hook(module, args, kwargs, module_output, name):
|
|
1156
1137
|
if module not in self.module_fwd_hook_context_by_module:
|
|
1157
1138
|
self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
|
|
@@ -1221,7 +1202,7 @@ class TrainerMon:
|
|
|
1221
1202
|
if self.monitor_mbs_grad:
|
|
1222
1203
|
self._hook_weights()
|
|
1223
1204
|
return
|
|
1224
|
-
|
|
1205
|
+
|
|
1225
1206
|
self.optimizer_mon.patch_grad_sync(self)
|
|
1226
1207
|
|
|
1227
1208
|
if self.enable_megatron or self.enable_deepspeed:
|
|
@@ -1281,6 +1262,7 @@ class TrainerMon:
|
|
|
1281
1262
|
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
|
|
1282
1263
|
out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
|
|
1283
1264
|
return out
|
|
1265
|
+
|
|
1284
1266
|
return wrapper
|
|
1285
1267
|
|
|
1286
1268
|
logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
|
|
@@ -1294,10 +1276,9 @@ class TrainerMon:
|
|
|
1294
1276
|
"""
|
|
1295
1277
|
遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。
|
|
1296
1278
|
"""
|
|
1297
|
-
context = self.grad_context
|
|
1298
1279
|
|
|
1299
1280
|
@torch.no_grad
|
|
1300
|
-
def param_hook(*args,
|
|
1281
|
+
def param_hook(*args, param, name):
|
|
1301
1282
|
key = name
|
|
1302
1283
|
if self.monitor_mbs_grad:
|
|
1303
1284
|
key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
|
|
@@ -1305,14 +1286,15 @@ class TrainerMon:
|
|
|
1305
1286
|
key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
|
|
1306
1287
|
self.register_param_call_id("param_hook", key)
|
|
1307
1288
|
param.micro_step += 1
|
|
1308
|
-
|
|
1289
|
+
grad_dict = {}
|
|
1309
1290
|
if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
|
|
1310
1291
|
if self.params_have_main_grad:
|
|
1311
1292
|
grad = param.main_grad
|
|
1312
1293
|
else:
|
|
1313
1294
|
grad = param.grad
|
|
1314
|
-
|
|
1295
|
+
grad_dict[key] = grad.clone()
|
|
1315
1296
|
|
|
1297
|
+
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
|
|
1316
1298
|
if param.micro_step == self.micro_batch_number:
|
|
1317
1299
|
param.micro_step = 0
|
|
1318
1300
|
|
|
@@ -1322,7 +1304,7 @@ class TrainerMon:
|
|
|
1322
1304
|
param_tmp = param.expand_as(param)
|
|
1323
1305
|
grad_acc = param_tmp.grad_fn.next_functions[0][0]
|
|
1324
1306
|
handle = grad_acc.register_hook(
|
|
1325
|
-
partial(param_hook,
|
|
1307
|
+
partial(param_hook, param=param, name=name))
|
|
1326
1308
|
self.grad_accs.append(grad_acc)
|
|
1327
1309
|
self.handles['wgrads'].append(handle)
|
|
1328
1310
|
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
|
|
|
35
35
|
class TensorConfig(BaseConfig):
|
|
36
36
|
def __init__(self, json_config):
|
|
37
37
|
super().__init__(json_config)
|
|
38
|
-
self.online_run_ut = json_config.get("online_run_ut", False)
|
|
39
|
-
self.nfs_path = json_config.get("nfs_path", "")
|
|
40
|
-
self.host = json_config.get("host", "")
|
|
41
|
-
self.port = json_config.get("port", -1)
|
|
42
|
-
self.tls_path = json_config.get("tls_path", "./")
|
|
43
|
-
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
44
38
|
self.check_config()
|
|
45
39
|
self._check_summary_mode()
|
|
46
40
|
self._check_file_format()
|
|
47
|
-
|
|
48
|
-
self._check_online_run_ut()
|
|
41
|
+
|
|
49
42
|
|
|
50
43
|
def _check_file_format(self):
|
|
51
44
|
if self.file_format is not None and self.file_format not in ["npy", "bin"]:
|
|
52
45
|
raise Exception("file_format is invalid")
|
|
53
46
|
|
|
54
|
-
def _check_online_run_ut(self):
|
|
55
|
-
if not isinstance(self.online_run_ut, bool):
|
|
56
|
-
raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
|
|
57
|
-
|
|
58
|
-
if not isinstance(self.online_run_ut_recompute, bool):
|
|
59
|
-
raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
|
|
60
|
-
|
|
61
|
-
if self.nfs_path:
|
|
62
|
-
check_file_or_directory_path(self.nfs_path, isdir=True)
|
|
63
|
-
return
|
|
64
|
-
|
|
65
|
-
if self.tls_path:
|
|
66
|
-
check_file_or_directory_path(self.tls_path, isdir=True)
|
|
67
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
|
|
68
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
|
|
69
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
|
|
70
|
-
crl_path = os.path.join(self.tls_path, "crl.pem")
|
|
71
|
-
if os.path.exists(crl_path):
|
|
72
|
-
check_file_or_directory_path(crl_path)
|
|
73
|
-
|
|
74
|
-
if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
|
|
75
|
-
raise Exception(f"host: {self.host} is invalid.")
|
|
76
|
-
|
|
77
|
-
if not isinstance(self.port, int) or not (0 < self.port <= 65535):
|
|
78
|
-
raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
|
|
79
|
-
|
|
80
47
|
|
|
81
48
|
class StatisticsConfig(BaseConfig):
|
|
82
49
|
def __init__(self, json_config):
|
|
@@ -257,12 +224,7 @@ class RunUTConfig(BaseConfig):
|
|
|
257
224
|
self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
|
|
258
225
|
self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
|
|
259
226
|
self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
|
|
260
|
-
|
|
261
|
-
self.nfs_path = json_config.get("nfs_path", "")
|
|
262
|
-
self.host = json_config.get("host", "")
|
|
263
|
-
self.port = json_config.get("port", -1)
|
|
264
|
-
self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
|
|
265
|
-
self.tls_path = json_config.get("tls_path", "./")
|
|
227
|
+
|
|
266
228
|
self.check_run_ut_config()
|
|
267
229
|
|
|
268
230
|
@classmethod
|
|
@@ -280,22 +242,11 @@ class RunUTConfig(BaseConfig):
|
|
|
280
242
|
if not os.path.exists(error_data_path):
|
|
281
243
|
raise Exception("error_data_path: %s does not exist" % error_data_path)
|
|
282
244
|
|
|
283
|
-
@classmethod
|
|
284
|
-
def check_nfs_path_config(cls, nfs_path):
|
|
285
|
-
if nfs_path:
|
|
286
|
-
FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
287
|
-
|
|
288
|
-
@classmethod
|
|
289
|
-
def check_tls_path_config(cls, tls_path):
|
|
290
|
-
if tls_path:
|
|
291
|
-
FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
292
245
|
|
|
293
246
|
def check_run_ut_config(self):
|
|
294
247
|
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
295
248
|
RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
|
|
296
249
|
RunUTConfig.check_error_data_path_config(self.error_data_path)
|
|
297
|
-
RunUTConfig.check_nfs_path_config(self.nfs_path)
|
|
298
|
-
RunUTConfig.check_tls_path_config(self.tls_path)
|
|
299
250
|
|
|
300
251
|
|
|
301
252
|
class GradToolConfig(BaseConfig):
|
|
@@ -15,9 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
from msprobe.core.common.utils import Const
|
|
17
17
|
from msprobe.core.service import BaseService
|
|
18
|
-
from msprobe.pytorch.attl_manager import ATTLManager
|
|
19
18
|
from msprobe.pytorch.common.log import logger
|
|
20
|
-
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
19
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
21
20
|
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
22
21
|
from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
|
|
23
22
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
@@ -25,9 +24,6 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
|
|
|
25
24
|
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
|
|
26
25
|
from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
|
|
27
26
|
|
|
28
|
-
if torch_version_above_or_equal_2:
|
|
29
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
30
|
-
|
|
31
27
|
|
|
32
28
|
class PytorchService(BaseService):
|
|
33
29
|
@property
|
|
@@ -45,12 +41,10 @@ class PytorchService(BaseService):
|
|
|
45
41
|
self.logger = logger
|
|
46
42
|
self.api_register = get_api_register()
|
|
47
43
|
self.module_processor = ModuleProcesser(self.data_collector.scope)
|
|
48
|
-
self.
|
|
49
|
-
self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
|
|
44
|
+
self.hook_manager = PytorchHookManager(self.data_collector, self.config)
|
|
50
45
|
self.api_template = ApiTemplate
|
|
51
46
|
|
|
52
47
|
def _register_hook(self):
|
|
53
|
-
self.attl_manager.attl_init()
|
|
54
48
|
if self._is_mix_level:
|
|
55
49
|
register_optimizer_hook(self.data_collector)
|
|
56
50
|
|
|
@@ -65,9 +59,6 @@ class PytorchService(BaseService):
|
|
|
65
59
|
self.module_processor.register_module_hook(self.model, self.build_hook)
|
|
66
60
|
self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
67
61
|
|
|
68
|
-
def _run_ut_dispatch(self, status):
|
|
69
|
-
if torch_version_above_or_equal_2:
|
|
70
|
-
run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
|
|
71
62
|
|
|
72
63
|
def _reset_status(self):
|
|
73
64
|
super()._reset_status()
|
|
@@ -298,8 +298,8 @@ class GraphBuilder:
|
|
|
298
298
|
no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
|
|
299
299
|
if not no_recompute_map:
|
|
300
300
|
return
|
|
301
|
-
#
|
|
302
|
-
no_recompute_ids_b =
|
|
301
|
+
# 拷贝非重计算节点字典用于反向模式
|
|
302
|
+
no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
|
|
303
303
|
|
|
304
304
|
del_indexes = []
|
|
305
305
|
for node_id, id_prefix in recompute_map.items():
|