mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
- msprobe/README.md +8 -5
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +64 -13
- msprobe/core/common/framework_adapter.py +10 -1
- msprobe/core/common/utils.py +17 -0
- msprobe/core/compare/utils.py +26 -6
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +2 -16
- msprobe/core/service.py +5 -16
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +0 -13
- msprobe/docs/05.data_dump_PyTorch.md +1 -1
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
- msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/19.monitor.md +4 -4
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +0 -70
- msprobe/pytorch/debugger/debugger_config.py +0 -10
- msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
- msprobe/pytorch/hook_module/api_register.py +14 -3
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +10 -14
- msprobe/visualization/builder/graph_builder.py +2 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/db_utils.py +42 -18
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/visualization/graph_service.py +20 -10
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py CHANGED

@@ -39,7 +39,12 @@ from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.core.common.utils import CompareException
 
 
-def split_json_file(input_file, num_splits, filter_api):
+def split_json_file(input_file, num_splits, filter_api, device_id):
+    max_processes = len(device_id) * 8
+    if num_splits > max_processes:
+        logger.warning(f"A device supports a maximum of 8 processes. "
+                       f"The total number of processes exceeds the limit, and it is set to {max_processes}.")
+        num_splits = max_processes
     forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
     input_dir = os.path.dirname(os.path.abspath(input_file))
     if filter_api:
@@ -88,7 +93,7 @@ def split_json_file(input_file, num_splits, filter_api):
             logger.error(f"File not found or could not be deleted: {file}")
         msg = 'ERROR: Split json file failed, please check the input file and try again.'
         raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
-    return split_files, total_items
+    return split_files, total_items, num_splits
 
 
 def signal_handler(signum, frame):
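For context, the new device_id argument only feeds a cap of eight worker processes per device. A standalone sketch of the rule (hypothetical helper name, not the tool's API):

```python
# Hypothetical sketch of the 8-processes-per-device cap introduced above.
def cap_num_splits(num_splits, device_ids):
    max_processes = len(device_ids) * 8  # at most 8 UT worker processes per device
    return min(num_splits, max_processes)

print(cap_num_splits(40, [0, 1]))  # 16: two devices allow at most 16 workers
print(cap_num_splits(10, [0, 1]))  # 10: under the limit, left unchanged
```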
@@ -127,7 +132,8 @@ def run_parallel_ut(config):
     def read_process_output(process):
         try:
             while True:
-                if process.poll() is not None:
+                # The child's stdout stream is separate from its process state, so check both: a non-None return code means the child exited; stdout being None means the stream has ended.
+                if process.poll() is not None or process.stdout is None:
                     break
                 output = process.stdout.readline()
                 if output == '':
@@ -175,12 +181,17 @@ def run_parallel_ut(config):
 
     try:
         for process in processes:
-            process.communicate()
+            process.wait()  # wait() only blocks; the previous communicate() also captured stdout/stderr
     except KeyboardInterrupt:
         logger.warning("Interrupted by user, terminating processes and cleaning up...")
     except Exception as e:
         logger.error(f"An unexpected error occurred: {e}")
     finally:
+        # update the progress bar one final time, so results written just before a child exits are not missed
+        if wait_for_file_write_complete(config.result_csv_path):
+            result_file = read_csv(config.result_csv_path)
+            completed_items = len(result_file)
+            progress_bar.update(completed_items - progress_bar.n)
         if progress_bar.n < config.total_items:
             logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
                            "the result CSV file will be utilized to resume the UT task.")
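The hunks above make two related fixes: the drain loop now stops when the child has exited or its stdout is gone, and the join switches from communicate() to wait() so output already consumed by read_process_output is not captured a second time. A self-contained sketch of the drain-then-wait pattern (throwaway child process, not the tool's code):

```python
import subprocess
import sys

# Spawn a short-lived child and drain its stdout line by line.
proc = subprocess.Popen(
    [sys.executable, "-c", "print('line 1'); print('line 2')"],
    stdout=subprocess.PIPE, text=True)

while True:
    if proc.stdout is None:  # no pipe to read (the tool also breaks early on poll())
        break
    output = proc.stdout.readline()
    if output == "":         # EOF: the child exited and the pipe is drained
        break
    print("child said:", output.strip())

proc.wait()  # wait() only blocks; unlike communicate(), it captures nothing
```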
@@ -195,6 +206,22 @@ def run_parallel_ut(config):
         logger.error(f"An unexpected error occurred: {e}")
 
 
+def wait_for_file_write_complete(file_path, timeout=3600):
+    last_size = 0
+    start_time = time.time()  # record the start time
+    while True:
+        current_size = os.path.getsize(file_path)
+        # an unchanged file size means the write has finished
+        if current_size == last_size:
+            return True
+        last_size = current_size
+        # check for timeout
+        if time.time() - start_time > timeout:
+            logger.error("write the result csv file timeout.")
+            return False
+        time.sleep(0.1)  # small polling delay
+
+
 def prepare_config(args):
     api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
                                         ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
@@ -203,7 +230,9 @@ def prepare_config(args):
     create_directory(out_path)
     out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
     out_path = out_path_checker.common_check()
-    split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
+    split_files, total_items, modified_num_splits = split_json_file(api_info, args.num_splits,
+                                                                    args.filter_api, args.device_id)
+    args.num_splits = modified_num_splits
     config_path = args.config_path if args.config_path else None
     if config_path:
         config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
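The new wait_for_file_write_complete helper declares a file fully written once its size is unchanged across one 0.1 s polling interval. A simplified standalone sketch of the same size-polling idea (hypothetical helper):

```python
import os
import tempfile
import time

def wait_until_stable(path, interval=0.1, timeout=5.0):
    """Return True once the file size stops changing, False on timeout."""
    last_size = -1
    deadline = time.time() + timeout
    while time.time() < deadline:
        size = os.path.getsize(path)
        if size == last_size:
            return True       # no growth across one interval: the writer looks done
        last_size = size
        time.sleep(interval)
    return False

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"api_name,status\nTorch.add.0,pass\n")
print(wait_until_stable(f.name))  # True: nothing is still writing
```

One quirk worth noting: the helper in the diff starts last_size at 0, so an empty result file is reported complete on the first check; the sketch starts at -1 to force at least one full polling interval.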
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py CHANGED

@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
 from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
 from msprobe.pytorch.common.utils import seed_all
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
     ExecParams
 
@@ -90,27 +88,22 @@ seed_all()
 
 def run_ut(config):
     logger.info("start UT test")
-
-
-
-    else:
-        logger.info(f"UT task result will be saved in {config.result_csv_path}")
-        logger.info(f"UT task details will be saved in {config.details_csv_path}")
+
+    logger.info(f"UT task result will be saved in {config.result_csv_path}")
+    logger.info(f"UT task details will be saved in {config.details_csv_path}")
 
     if config.save_error_data:
         logger.info(f"UT task error_data will be saved in {config.error_data_path}")
     compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
 
-
-
-
-
-
-
-
-
-    api_name_set = set()
-    run_api_offline(config, compare, api_name_set)
+
+    csv_df = read_csv(config.result_csv_path)
+    try:
+        api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
+    except IndexError:
+        logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
+        api_name_set = set()
+    run_api_offline(config, compare, api_name_set)
     for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
         change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
         change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
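The rewritten resume path derives the set of already-tested APIs from column 0 of the result CSV; the itertuples call implies read_csv returns a pandas DataFrame. An equivalent standalone sketch with toy data (assumes pandas):

```python
import pandas as pd

# Toy result CSV contents: the first column holds tested API names.
csv_df = pd.DataFrame({"API name": ["Torch.add.0", "Torch.mul.0"],
                       "Status": ["pass", "pass"]})

try:
    # With index=False and name=None, each row is a plain tuple; row[0] is the API name.
    api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
except IndexError:  # a zero-column frame yields empty tuples
    api_name_set = set()

print(sorted(api_name_set))  # ['Torch.add.0', 'Torch.mul.0']
```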
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
         gc.collect()
 
 
-def run_api_online(config, compare):
-    attl = init_attl(config.online_config)
-    dispatcher = ConsumerDispatcher(compare=compare)
-    dispatcher.start(handle_func=run_torch_api_online, config=config)
-
-    def tcp_communication_flow():
-        while True:
-            api_data = attl.recv()
-            if api_data == 'STOP_':
-                continue
-            if api_data == 'KILL_':
-                time.sleep(1)
-                logger.info("==========接收到STOP信号==========")
-                dispatcher.stop()
-                attl.stop_serve()
-                time.sleep(1)
-                break
-            if not isinstance(api_data, ApiData):
-                continue
-            api_full_name = api_data.name
-            _, api_name = extract_basic_api_segments(api_full_name)
-            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
-                continue
-            if api_data.rank in config.online_config.rank_list:
-                dispatcher.update_consume_queue(api_data)
-
-    def shared_storage_communication_flow():
-        flag_num = -1
-        while True:
-            api_data = attl.download()
-            if api_data == "start":
-                if flag_num == -1:
-                    flag_num += 1
-                flag_num += 1
-            if api_data == "end":
-                flag_num -= 1
-            if flag_num == 0:
-                dispatcher.stop()
-                break
-            if not isinstance(api_data, ApiData):
-                continue
-            api_full_name = api_data.name
-            _, api_name = extract_basic_api_segments(api_full_name)
-            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
-                continue
-            if api_data.rank in config.online_config.rank_list:
-                dispatcher.update_consume_queue(api_data)
-
-    if config.online_config.nfs_path:
-        shared_storage_communication_flow()
-    else:
-        tcp_communication_flow()
-
-
 def blacklist_and_whitelist_filter(api_name, black_list, white_list):
     """
     run api(api_name) if api_name not in black_list and in white_list.
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
     return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
 
 
-def run_torch_api_online(api_full_name, api_data, backward_content):
-    in_fwd_data_list = []
-    api_type, api_name = extract_basic_api_segments(api_full_name)
-    args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
-    in_fwd_data_list.append(args)
-    in_fwd_data_list.append(kwargs)
-    if kwargs.get("device"):
-        del kwargs["device"]
-
-    device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
-    device_out = exec_api(device_exec_params)
-    device_out = move2device_exec(device_out, "cpu")
-    return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
-
-
 def check_need_grad(api_info_dict):
     need_grad = True
     if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
     return error_data_path
 
 
-def init_attl(config):
-    """config: OnlineConfig"""
-    attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
-                                  connect_ip=config.host,
-                                  connect_port=config.port,
-                                  nfs_path=config.nfs_path,
-                                  tls_path=config.tls_path))
-    return attl
-
-
 def _run_ut_parser(parser):
     parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
                         help="<Optional> The api param tool result file: generate from api param tool, "
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
     _run_ut_parser(parser)
     args = parser.parse_args(sys.argv[1:])
     run_ut_command(args)
-
-
-def checked_online_config(online_config):
-    if not online_config.is_online:
-        return
-    if not isinstance(online_config.is_online, bool):
-        raise ValueError("is_online must be bool type")
-    # rank_list
-    if not isinstance(online_config.rank_list, list):
-        raise ValueError("rank_list must be a list")
-    if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
-        raise ValueError("All elements in rank_list must be integers")
-
-    # nfs_path
-    if online_config.nfs_path:
-        check_file_or_directory_path(online_config.nfs_path, isdir=True)
-        return
-    # tls_path
-    if online_config.tls_path:
-        check_file_or_directory_path(online_config.tls_path, isdir=True)
-        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
-        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
-        check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
-        crl_path = os.path.join(online_config.tls_path, "crl.pem")
-        if os.path.exists(crl_path):
-            check_file_or_directory_path(crl_path)
-
-    # host and port
-    if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
-        raise Exception(f"host: {online_config.host} is invalid.")
-    if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
-        raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
 
 
 def run_ut_command(args):
@@ -525,7 +407,7 @@ def run_ut_command(args):
     else:
        checker_config = CheckerConfig()
 
-    if not checker_config.is_online and not args.api_info_file:
+    if not args.api_info_file:
         logger.error("Please provide api_info_file for offline run ut.")
         raise Exception("Please provide api_info_file for offline run ut.")
 
@@ -588,8 +470,6 @@ def run_ut_command(args):
     global UT_ERROR_DATA_DIR
     UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
     error_data_path = initialize_save_error_data(error_data_path)
-    online_config = checker_config.get_online_config()
-    checked_online_config(online_config)
     config_params = {
         'forward_content': forward_content,
         'backward_content': backward_content,
msprobe/pytorch/common/utils.py CHANGED

@@ -337,56 +337,6 @@ def save_pt(tensor, filepath):
     change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
 
 
-class TypeCheckingUnpickler(pickle.Unpickler):
-    """
-    This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
-    It overrides the find_class method to add type checking functionality.
-    """
-    allowed_types = [
-        "str",
-        "ApiData",
-        "OrderedDict",
-        "_rebuild_tensor_v2",  # from torch.utils
-        "_load_from_bytes"  # from torch.storage
-    ]
-
-    def find_class(self, module, name):
-        """
-        Method to find the class of the object to be unpickled.
-        Throws pickle.UnpicklingError If the object type is not in the allowed types list.
-        """
-        if name in self.allowed_types:
-            return super().find_class(module, name)
-        raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
-
-
-def save_pkl(tensor, filepath):
-    """Save ApiData or str objection by pickle"""
-    check_path_before_create(filepath)
-    filepath = os.path.realpath(filepath)
-    try:
-        with FileOpen(filepath, 'wb') as f:
-            pickle.dump(tensor, f)
-    except Exception as e:
-        logger.error("Save pt file failed, please check according possible error causes: "
-                     "1. out of disk space or disk error, "
-                     "2. no permission to write files, etc.")
-        raise RuntimeError(f"save pt file {filepath} failed") from e
-    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
-
-
-def load_pkl(pt_path):
-    """Load ApiData or str objection by pickle for accuracy_checker_online"""
-    check_file_or_directory_path(pt_path)
-    pt_path = os.path.realpath(pt_path)
-    try:
-        with FileOpen(pt_path, 'rb') as f:
-            pt = TypeCheckingUnpickler(f).load()
-    except Exception as e:
-        raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
-    return pt
-
-
 def is_recomputation():
     """Check if the current operation is in the re-computation phase.
 
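The removed TypeCheckingUnpickler is the allow-list defence recommended in the pickle documentation: find_class refuses to resolve any global whose name is not explicitly permitted. A minimal usage demonstration with a toy allow-list (illustrative, not the tool's list):

```python
import io
import pickle
from collections import OrderedDict

class RestrictedUnpickler(pickle.Unpickler):
    allowed_types = ["OrderedDict"]  # toy allow-list

    def find_class(self, module, name):
        if name in self.allowed_types:
            return super().find_class(module, name)
        raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))

ok = pickle.dumps(OrderedDict(a=1))
print(RestrictedUnpickler(io.BytesIO(ok)).load())  # OrderedDict([('a', 1)])

try:
    RestrictedUnpickler(io.BytesIO(pickle.dumps(complex(1, 2)))).load()
except pickle.UnpicklingError as e:
    print(e)  # Unsupported object type: builtins.complex
```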
@@ -471,23 +421,3 @@ def register_forward_hook(module, forward_hook):
         module.register_forward_hook(forward_hook, with_kwargs=True)
     else:
         module.register_forward_hook(forward_hook)
-
-
-def save_api_data(api_data):
-    """Save data to io stream"""
-    try:
-        io_buff = io.BytesIO()
-        torch.save(api_data, io_buff)
-    except Exception as e:
-        raise RuntimeError(f"save api_data to io_buff failed") from e
-    return io_buff
-
-
-def load_api_data(api_data_bytes):
-    """Load data from bytes stream"""
-    try:
-        buffer = io.BytesIO(api_data_bytes)
-        buffer = torch.load(buffer, map_location="cpu")
-    except Exception as e:
-        raise RuntimeError(f"load api_data from bytes failed") from e
-    return buffer
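The removed save_api_data/load_api_data pair amounts to a torch.save/torch.load round-trip through an in-memory BytesIO buffer. A usage sketch of that round-trip:

```python
import io
import torch

api_data = {"args": (torch.ones(2, 2),), "kwargs": {"alpha": 0.5}}

io_buff = io.BytesIO()
torch.save(api_data, io_buff)   # serialize into the in-memory stream
io_buff.seek(0)                 # rewind before reading it back
restored = torch.load(io_buff, map_location="cpu")
print(restored["kwargs"])       # {'alpha': 0.5}
```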
msprobe/pytorch/debugger/debugger_config.py CHANGED

@@ -48,16 +48,6 @@ class DebuggerConfig:
             "max_sample": task_config.max_sample
         }
 
-        self.online_run_ut = False
-        if self.task == Const.TENSOR:
-            # dump api tensor and collaborate with online run_ut
-            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
-            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
-            self.tls_path = task_config.tls_path if task_config.tls_path else ""
-            self.host = task_config.host if task_config.host else ""
-            self.port = task_config.port if task_config.port else -1
-            self.online_run_ut_recompute = task_config.online_run_ut_recompute \
-                if isinstance(task_config.online_run_ut_recompute, bool) else False
 
         self.check()
         self._check_statistics_config(task_config)
msprobe/pytorch/dump/module_dump/module_processer.py CHANGED

@@ -63,9 +63,11 @@ def wrap_forward_with_hook_safety(module):
         except _StopRecomputationError as e:
             exception_output = None
             if len(module._forward_hooks.values()) > 0:
-                # msprobe's forward_hook
-                hook_fn = list(module._forward_hooks.values())[0]
-                hook_fn(module, args, kwargs, exception_output)
+                # run only msprobe's forward_hook; the hook name always contains 'ModuleProcesser.'
+                for hook_fn in module._forward_hooks.values():
+                    if 'ModuleProcesser' in str(hook_fn):
+                        hook_fn(module, args, kwargs, exception_output)
+                        break
             raise e
 
     if torch_version_above_or_equal_21:
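The fix above stops assuming msprobe's hook is the first entry in module._forward_hooks and instead matches it by the 'ModuleProcesser' marker in the hook's repr. A toy demonstration of that matching (illustrative class names, assumes torch):

```python
import torch

class ModuleProcesser:  # stand-in: the qualified name shows up in the hook's repr
    @staticmethod
    def forward_hook(module, inputs, output):
        pass

def user_hook(module, inputs, output):
    pass

m = torch.nn.Linear(2, 2)
m.register_forward_hook(user_hook)
m.register_forward_hook(ModuleProcesser.forward_hook)

for hook_fn in m._forward_hooks.values():
    if "ModuleProcesser" in str(hook_fn):  # e.g. <function ModuleProcesser.forward_hook ...>
        print("matched:", hook_fn)
        break
```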
@@ -152,7 +154,13 @@ class ModuleProcesser:
         modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
+
+            model_list = []
             for name, module in modules_and_names:
+                model_list.append((name, module))
+
+            is_verl = "verl" in sys.modules
+            for idx, (name, module) in enumerate(model_list):
                 if recursive and module == model:
                     continue
                 if not is_torch_nn_module(module):
@@ -163,6 +171,13 @@ class ModuleProcesser:
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+
+                # skip the first and last layers in the verl scenario
+                if is_verl and (idx == 1 or idx == len(model_list) - 1):
+                    logger.warning(f"The module {name} is the first or last layer in verl scenario, "
+                                   f"the data dump for this module will be skipped.")
+                    continue
+
                 setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
                 prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
msprobe/pytorch/hook_module/api_register.py CHANGED

@@ -22,6 +22,7 @@ import torch.distributed as dist
 
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import load_yaml
+from msprobe.core.common.runtime import Runtime
 from msprobe.core.data_dump.api_registry import ApiRegistry
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import (
@@ -91,6 +92,12 @@ _inner_used_api = {
 }
 
 
+def reset_dist_collect_func():
+    global dist_data_collect_func, dist_batch_data_collect_func
+    dist_data_collect_func.clear()
+    dist_batch_data_collect_func.clear()
+
+
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
     return module.api_func(*args, **kwargs)
@@ -114,9 +121,9 @@ def dist_module_forward(module, *args, **kwargs):
 
         return store_data
 
-    if use_async_op_flag or module.api_name in ['isend', 'irecv']:
+    if Runtime.is_running and (use_async_op_flag or module.api_name in ['isend', 'irecv']):
         dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook)
-    if module.api_name == 'batch_isend_irecv':
+    if Runtime.is_running and module.api_name == 'batch_isend_irecv':
         dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)])
     return handle
 
@@ -135,13 +142,17 @@ def redirect_wait():
             store_func = dist_data_collect_func.pop(args[0])
             store_func()
             return
+        remove_value = None
         for value in dist_batch_data_collect_func:
             if args[0] in value[0]:
                 value[0].remove(args[0])
                 if len(value[0]) == 0:
                     store_func = value[1]
                     store_func()
-
+                    remove_value = value
+                break
+        if remove_value:
+            dist_batch_data_collect_func.remove(remove_value)
 
     return wrapped_wait
 
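The redirect_wait change finishes the batched-async bookkeeping: when the last pending handle of a batch_isend_irecv entry is waited on, its callback fires and the exhausted entry is now removed from the registry instead of lingering. A framework-free sketch of the same handle-to-callback pattern:

```python
# Each registry entry is [set_of_pending_handles, callback], as in the diff above.
batch_collect = []

def register_batch(handles, callback):
    batch_collect.append([set(handles), callback])

def wrapped_wait(handle):
    remove_value = None
    for value in batch_collect:
        if handle in value[0]:
            value[0].remove(handle)
            if len(value[0]) == 0:   # last pending handle: fire the callback
                value[1]()
                remove_value = value
            break
    if remove_value:                 # drop the exhausted entry (the 8.3.2 fix)
        batch_collect.remove(remove_value)

register_batch(["h1", "h2"], lambda: print("batch done"))
wrapped_wait("h1")
wrapped_wait("h2")    # prints "batch done"
print(batch_collect)  # []
```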
msprobe/pytorch/monitor/module_hook.py CHANGED

@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
-
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
-
 FORMAT_MAPPING = {
     MonitorConst.TENSORBOARD: SummaryWriterWithAD,
     MonitorConst.CSV: CSVWriterWithAD,
@@ -150,15 +148,11 @@ class GradContext:
     def __init__(self) -> None:
         self.pre = {}
         self.post = {}
-        self.acc_metric = {}
-        self.acc = {}
         self.actv = {}
 
     def reset(self):
         self.pre.clear()
         self.post.clear()
-        self.acc_metric.clear()
-        self.acc.clear()
         self.actv.clear()
 
 
@@ -510,18 +504,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return {}, {}
 
-        if self.weight_hooked:
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
         get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
-
-
-        if self.weight_hooked:
-            unreduced_grad = self.grad_context.acc_metric
-        else:
-            unreduced_grad = self.grad_context.pre
-
-        return reduced_grad, unreduced_grad
+        return self.grad_context.post, self.grad_context.pre
 
 
     def generate_xy_metrics(self):
@@ -529,7 +513,6 @@ class TrainerMon:
         actv.update(fwd_context.actv)
 
         actv_grad = self.grad_context.actv
-
         return actv, actv_grad
 
     def reload_xy(self, xy_distribution=False):
@@ -607,11 +590,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return
 
-        if self.monitor_mbs_grad:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
-                                              use_micro_step=self.monitor_mbs_grad)
-        else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
+                                          use_micro_step=self.monitor_mbs_grad)
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer):
@@ -732,9 +712,9 @@ class TrainerMon:
         # static mode can save at step 0, dynamic mode cannot: dynamic is designed to start on the step after a reset, and self.monitoring is still False at step 0
         if self.monitoring:
             module_rank_valid = not self.module_rank_list or (
-                dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+                    dist.is_initialized() and dist.get_rank() in self.module_rank_list)
             step_condition = (context.step >= self.start_step and (
-                context.step - self.start_step) % self.step_interval == 0)
+                    context.step - self.start_step) % self.step_interval == 0)
             if module_rank_valid and step_condition:
                 self.has_collect_times += 1
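The collection condition reindented above is a plain stride test over training steps; a worked example with concrete values:

```python
# Worked example of step_condition: start_step=10, step_interval=5
start_step, step_interval = 10, 5
collected = [step for step in range(30)
             if step >= start_step and (step - start_step) % step_interval == 0]
print(collected)  # [10, 15, 20, 25]
```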
@@ -791,6 +771,7 @@ class TrainerMon:
             hook(optimizer, args, kwargs)
             step_final_hook(optimizer, args, kwargs)
             return out
+
         return wrapper
 
     optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
@@ -1013,11 +994,11 @@ class TrainerMon:
                 vpp_stage + module_name,
             ]:
                 if pattern in l2_targets:
-                    return pattern
+                    return pattern
         elif hook_name in ["linear_hook"]:
             return vpp_stage + squash_param_name(module_name, self.squash_name)
         return ""
-
+
     def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:
             # nothing to hook
@@ -1151,7 +1132,7 @@ class TrainerMon:
                 context.micro_step = 0
                 context.step += 1
             return
-
+
         def stack_hook(module, args, kwargs, module_output, name):
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
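patch_step, touched cosmetically in the first hunk above, wraps the optimizer class's step method so monitoring hooks run around every update. A minimal sketch of the wrapping pattern with toy hooks (assumes torch; simplified relative to the tool's two-argument patch_step):

```python
import torch

def patch_step(func):
    def wrapper(optimizer, *args, **kwargs):
        print("pre-step hook")                 # e.g. collect pre-update state
        out = func(optimizer, *args, **kwargs)
        print("step-final hook")               # e.g. flush metrics for this step

        return out

    return wrapper

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
opt.__class__.step = patch_step(opt.__class__.step)
opt.step()  # prints both hook messages around the real update
```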
@@ -1221,7 +1202,7 @@ class TrainerMon:
         if self.monitor_mbs_grad:
             self._hook_weights()
             return
-
+
         self.optimizer_mon.patch_grad_sync(self)
 
         if self.enable_megatron or self.enable_deepspeed:
@@ -1281,6 +1262,7 @@ class TrainerMon:
             get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
             out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
             return out
+
         return wrapper
 
     logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
@@ -1294,10 +1276,9 @@ class TrainerMon:
         """
         Iterate over each parameter's gradient-accumulation function (grad_acc) and attach a hook, so that after all of the parameter's gradients are computed, the pre-aggregation gradient data can be collected.
         """
-        context = self.grad_context
 
         @torch.no_grad
-        def param_hook(*args,
+        def param_hook(*args, param, name):
             key = name
             if self.monitor_mbs_grad:
                 key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
@@ -1305,14 +1286,15 @@ class TrainerMon:
             key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
             self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
+            grad_dict = {}
             if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                 if self.params_have_main_grad:
                     grad = param.main_grad
                 else:
                     grad = param.grad
-
+                grad_dict[key] = grad.clone()
 
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
 
@@ -1322,7 +1304,7 @@ class TrainerMon:
             param_tmp = param.expand_as(param)
             grad_acc = param_tmp.grad_fn.next_functions[0][0]
             handle = grad_acc.register_hook(
-                partial(param_hook,
+                partial(param_hook, param=param, name=name))
             self.grad_accs.append(grad_acc)
             self.handles['wgrads'].append(handle)
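param_hook is registered through the classic AccumulateGrad trick visible in the last hunk: param.expand_as(param) builds a graph view whose grad_fn.next_functions[0][0] is the parameter's gradient-accumulation node, and a hook there fires once the parameter's gradient for the backward pass is ready. A standalone sketch:

```python
import torch

param = torch.nn.Parameter(torch.ones(3))

# expand_as builds a view whose grad_fn points back at param's AccumulateGrad node.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

def param_hook(*unused):
    # Fires once param.grad for this backward pass has been accumulated.
    print("accumulated grad:", param.grad.tolist())

handle = grad_acc.register_hook(param_hook)

(param * 2).sum().backward()  # prints: accumulated grad: [2.0, 2.0, 2.0]
handle.remove()
```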