mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
- msprobe/README.md +7 -5
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +105 -27
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +2 -1
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +37 -3
- msprobe/core/service.py +18 -5
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +3 -3
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +170 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- msprobe/visualization/utils.py +11 -1
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
|
@@ -51,6 +51,8 @@ from msprobe.pytorch.pt_config import parse_json_config
|
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
52
|
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
54
56
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
|
|
55
57
|
ExecParams
|
|
56
58
|
|
|
@@ -88,22 +90,27 @@ seed_all()
|
|
|
88
90
|
|
|
89
91
|
def run_ut(config):
|
|
90
92
|
logger.info("start UT test")
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
93
|
+
if config.online_config.is_online:
|
|
94
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
|
|
95
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
|
|
96
|
+
else:
|
|
97
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
98
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
94
99
|
|
|
95
100
|
if config.save_error_data:
|
|
96
101
|
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
97
102
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
98
103
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
104
|
+
if config.online_config.is_online:
|
|
105
|
+
run_api_online(config, compare)
|
|
106
|
+
else:
|
|
107
|
+
csv_df = read_csv(config.result_csv_path)
|
|
108
|
+
try:
|
|
109
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
110
|
+
except IndexError:
|
|
111
|
+
logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
|
|
112
|
+
api_name_set = set()
|
|
113
|
+
run_api_offline(config, compare, api_name_set)
|
|
107
114
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
108
115
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
109
116
|
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -157,6 +164,60 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
157
164
|
gc.collect()
|
|
158
165
|
|
|
159
166
|
|
|
167
|
+
def run_api_online(config, compare):
    """Receive API data from the dumping side and feed it to a ConsumerDispatcher.

    The transport is selected by ``config.online_config``: when ``nfs_path`` is set the data
    is polled from shared storage, otherwise it is received over TCP via ``init_attl``.
    Every received ``ApiData`` that passes the black/white-list filter and whose rank is in
    ``config.online_config.rank_list`` is queued on the dispatcher, which handles it with
    ``run_torch_api_online``.

    Args:
        config: run_ut configuration; ``online_config``, ``black_list`` and ``white_list``
            are read.
        compare: Comparator handed to the ConsumerDispatcher.
    """
    attl = init_attl(config.online_config)
    dispatcher = ConsumerDispatcher(compare=compare)
    dispatcher.start(handle_func=run_torch_api_online, config=config)
    # Seconds to wait before re-polling an empty shared-storage directory.
    idle_poll_interval = 0.1

    def dispatch_api_data(api_data):
        """Queue ``api_data`` if it is ApiData, is not filtered out and belongs to a selected rank."""
        if not isinstance(api_data, ApiData):
            return
        _, api_name = extract_basic_api_segments(api_data.name)
        if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
            return
        if api_data.rank in config.online_config.rank_list:
            dispatcher.update_consume_queue(api_data)

    def tcp_communication_flow():
        while True:
            api_data = attl.recv()
            if api_data == 'STOP_':
                continue
            if api_data == 'KILL_':
                time.sleep(1)
                logger.info("==========接收到STOP信号==========")
                dispatcher.stop()
                attl.stop_serve()
                time.sleep(1)
                break
            dispatch_api_data(api_data)

    def shared_storage_communication_flow():
        # Producers that announced "start" but not yet "end"; -1 until the first "start" is seen.
        flag_num = -1
        while True:
            api_data = attl.download()
            if api_data is None:
                # download() found no file (or could not load one): back off instead of
                # busy-polling the shared directory at full CPU.
                time.sleep(idle_poll_interval)
                continue
            if api_data == "start":
                if flag_num == -1:
                    flag_num += 1
                flag_num += 1
            if api_data == "end":
                flag_num -= 1
                if flag_num == 0:
                    dispatcher.stop()
                    break
            dispatch_api_data(api_data)

    if config.online_config.nfs_path:
        shared_storage_communication_flow()
    else:
        tcp_communication_flow()
|
|
219
|
+
|
|
220
|
+
|
|
160
221
|
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
161
222
|
"""
|
|
162
223
|
run api(api_name) if api_name not in black_list and in white_list.
|
|
@@ -254,6 +315,21 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
254
315
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
255
316
|
|
|
256
317
|
|
|
318
|
+
def run_torch_api_online(api_full_name, api_data, backward_content):
    """Re-execute one API received from the remote (dumping) side on ``current_device``.

    Args:
        api_full_name: full API name, split by ``extract_basic_api_segments`` into type and name.
        api_data: received ApiData; its ``args``, ``kwargs``, ``result`` and ``rank`` are used.
        backward_content: not used by this function.

    Returns:
        UtDataInfo in which the result recorded remotely and the local re-execution result
        (moved to the CPU) occupy the slots run_torch_api uses for ``device_out`` and ``out``
        respectively; the gradient-related slots are None.
    """
    api_type, api_name = extract_basic_api_segments(api_full_name)
    call_args = api_data.args
    call_kwargs = api_data.kwargs
    remote_out = api_data.result
    # call_kwargs is stored by reference, so dropping "device" below also affects this list.
    fwd_inputs = [call_args, call_kwargs]
    if call_kwargs.get("device"):
        call_kwargs.pop("device")

    exec_params = ExecParams(api_type, api_name, current_device, call_args, call_kwargs, False, None)
    local_out = move2device_exec(exec_api(exec_params), "cpu")
    return UtDataInfo(None, None, remote_out, local_out, None, fwd_inputs, None, rank=api_data.rank)
|
|
331
|
+
|
|
332
|
+
|
|
257
333
|
def check_need_grad(api_info_dict):
|
|
258
334
|
need_grad = True
|
|
259
335
|
if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
|
|
@@ -313,6 +389,16 @@ def initialize_save_error_data(error_data_path):
|
|
|
313
389
|
return error_data_path
|
|
314
390
|
|
|
315
391
|
|
|
392
|
+
def init_attl(config):
    """Create the benchmark-side ATTL session described by ``config`` (an OnlineConfig).

    The session id is ``'gpu'`` and ``is_benchmark_device`` is True; host, port, nfs_path
    and tls_path are copied from ``config``.
    """
    session_config = ATTLConfig(
        is_benchmark_device=True,
        connect_ip=config.host,
        connect_port=config.port,
        nfs_path=config.nfs_path,
        tls_path=config.tls_path,
    )
    return ATTL('gpu', session_config)
|
|
400
|
+
|
|
401
|
+
|
|
316
402
|
def _run_ut_parser(parser):
|
|
317
403
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
|
|
318
404
|
help="<Optional> The api param tool result file: generate from api param tool, "
|
|
@@ -395,6 +481,38 @@ def _run_ut(parser=None):
|
|
|
395
481
|
_run_ut_parser(parser)
|
|
396
482
|
args = parser.parse_args(sys.argv[1:])
|
|
397
483
|
run_ut_command(args)
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def checked_online_config(online_config):
    """Validate the online run_ut configuration; do nothing when online mode is off.

    Checks, in order: ``is_online`` is a bool and ``rank_list`` is a list of ints. Then, if
    ``nfs_path`` is set, only that directory is checked (shared-storage mode needs no
    network settings). Otherwise the optional TLS directory and its files are checked,
    followed by ``host`` (must match ``Const.ipv4_pattern``) and ``port`` (1-65535).

    Args:
        online_config: object exposing ``is_online``, ``rank_list``, ``nfs_path``,
            ``tls_path``, ``host`` and ``port``.

    Raises:
        ValueError: if a field has the wrong type or an invalid value. Path problems are
            reported by ``check_file_or_directory_path``.
    """
    if not online_config.is_online:
        return
    if not isinstance(online_config.is_online, bool):
        raise ValueError("is_online must be bool type")

    # rank_list; bool is a subclass of int, so it is rejected explicitly.
    if not isinstance(online_config.rank_list, list):
        raise ValueError("rank_list must be a list")
    if not all(isinstance(rank, int) and not isinstance(rank, bool) for rank in online_config.rank_list):
        raise ValueError("All elements in rank_list must be integers")

    # nfs_path: shared-storage mode does not use host/port/TLS.
    if online_config.nfs_path:
        check_file_or_directory_path(online_config.nfs_path, isdir=True)
        return

    # tls_path
    if online_config.tls_path:
        check_file_or_directory_path(online_config.tls_path, isdir=True)
        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
        check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
        crl_path = os.path.join(online_config.tls_path, "crl.pem")
        if os.path.exists(crl_path):
            check_file_or_directory_path(crl_path)

    # host and port
    host = online_config.host
    if not isinstance(host, str) or not re.match(Const.ipv4_pattern, host):
        raise ValueError(f"host: {host} is invalid.")
    port = online_config.port
    if isinstance(port, bool) or not isinstance(port, int) or not 0 < port <= 65535:
        raise ValueError(f"port: {port} is invalid, port range 1-65535.")
|
|
398
516
|
|
|
399
517
|
|
|
400
518
|
def run_ut_command(args):
|
|
@@ -407,7 +525,7 @@ def run_ut_command(args):
|
|
|
407
525
|
else:
|
|
408
526
|
checker_config = CheckerConfig()
|
|
409
527
|
|
|
410
|
-
if not args.api_info_file:
|
|
528
|
+
if not checker_config.is_online and not args.api_info_file:
|
|
411
529
|
logger.error("Please provide api_info_file for offline run ut.")
|
|
412
530
|
raise Exception("Please provide api_info_file for offline run ut.")
|
|
413
531
|
|
|
@@ -470,6 +588,8 @@ def run_ut_command(args):
|
|
|
470
588
|
global UT_ERROR_DATA_DIR
|
|
471
589
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
472
590
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
591
|
+
online_config = checker_config.get_online_config()
|
|
592
|
+
checked_online_config(online_config)
|
|
473
593
|
config_params = {
|
|
474
594
|
'forward_content': forward_content,
|
|
475
595
|
'backward_content': backward_content,
|
|
File without changes
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import glob
|
|
17
|
+
import os.path
|
|
18
|
+
import time
|
|
19
|
+
from multiprocessing import Queue
|
|
20
|
+
from typing import Optional, Union, Dict, Any
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
26
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
|
|
27
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
28
|
+
from msprobe.core.common.file_utils import remove_path
|
|
29
|
+
from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
31
|
+
|
|
32
|
+
BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ATTLConfig:
    """Settings for one ATTL session (transport choice, endpoint, storage and queueing)."""

    # True: ATTL starts a TCPServer (receiving/benchmark side); False: a TCPClient may be started.
    is_benchmark_device: bool
    # Address the TCPClient connects to.
    connect_ip: str
    # Port used by the TCPServer / TCPClient.
    connect_port: int
    # storage_config
    # Shared-storage directory; when set, ATTL exchanges data through files instead of TCP.
    nfs_path: Optional[str] = None
    # Directory with TLS material, forwarded to TCPServer / TCPClient.
    tls_path: Optional[str] = None
    # Forwarded to TCPServer / TCPClient as their check_sum argument.
    check_sum: bool = True
    # Intended capacity of the session's receive queue.
    queue_size: int = 50
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ATTL:
    """Tensor transport layer between the dumping side and the benchmark side.

    The transport is chosen from ``session_config`` when the session is created:

    * ``nfs_path`` set: no socket is opened; ApiData and control strings are exchanged as
      pickle files in that directory through :meth:`upload` / :meth:`download`.
    * ``is_benchmark_device`` set: a ``TCPServer`` is started that fills ``data_queue``;
      items are read back with :meth:`recv`.
    * otherwise, when ``need_dump`` is true: a ``TCPClient`` is started and data is pushed
      with :meth:`send`.
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # Capacity comes from the config instead of a hard-coded 50 (50 is also the config default).
        self.data_queue = Queue(maxsize=self.session_config.queue_size)
        self.dequeue_list = []
        # Set by recv() when a KILL_ message arrives; recv() then drains data_queue and
        # reports "KILL_" once it is empty.
        self.message_end = False
        self.kill_progress = False
        self.nfs_path = None
        if self.session_config.nfs_path:
            self.nfs_path = self.session_config.nfs_path
        elif self.session_config.is_benchmark_device:
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def stop_serve(self):
        """Stop the TCP server if this session runs one; otherwise do nothing."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """Serialize ``buffer`` and queue it on the TCP client (dumping side).

        ApiData is first rebuilt with its tensors on the CPU, and a ``device`` kwarg is
        dropped. Payloads that ``save_api_data`` cannot serialize are logged and skipped.

        Raises:
            ConnectionError: if the client has already signalled exit (connection lost).
        """
        # if tcp connection lost,
        if self.socket_manager.signal_exit:
            raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")

        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

        if 'device' in buffer.kwargs:
            buffer.kwargs.pop('device')
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        try:
            io_buff = save_api_data(buffer)
        except Exception as e:
            # The payload is dropped, so report it as a warning rather than info.
            self.logger.warning(f"{buffer.name} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Return the next item received by the TCP server.

        With ``timeout_ms == 0`` it polls ``data_queue`` every 0.1 s until something arrives.
        With ``timeout_ms > 0`` it waits that long once and gives up if the queue is still
        empty. Control messages are translated:

        * ``b"STOP_"`` returns ``"STOP_"``.
        * ``b"KILL_"`` returns ``"STOP_"`` and marks the stream as ending.
        * Once the stream is ending and the queue is drained, ``"KILL_"`` is returned.

        Other payloads are deserialized with ``load_api_data``. An empty string is returned
        on timeout or when deserialization fails.
        """
        buffer = ''
        while not buffer:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not buffer and not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if not buffer and timeout_ms > 0:  # timeout is the only case we give up (returns '')
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # back off before polling the queue again
        if not buffer:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
        else:
            if buffer == b"STOP_":
                return "STOP_"
            if buffer == b"KILL_":
                self.message_end = True
                return "STOP_"
            if buffer == b"KILL_CONFIRM":
                self.kill_progress = True
                return "KILL_"
            try:
                buffer = load_api_data(buffer)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            if isinstance(buffer, bytes):
                return ''
            if isinstance(buffer, str):
                return buffer

        return buffer

    def upload(self, buffer: BufferType):
        """Write ``buffer`` as a pickle file into the shared-storage directory.

        ApiData is rebuilt with CPU tensors and stored as ``<name>.pt``. Any other buffer
        must be a string (a control marker such as "start"/"end") and is stored as
        ``<buffer>_<time_ns>_<pid>``. The nanosecond timestamp plus process id keeps markers
        written by several producers at the same moment from overwriting each other.
        Save errors are logged, not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            # NOTE(review): ApiData objects with the same name (e.g. from different ranks) map to
            # the same file and overwrite each other — confirm names are unique per rank.
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path,
                                     buffer + f"_{time.time_ns()}_{os.getpid()}")

        try:
            save_pkl(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and consume one file from the shared-storage directory.

        Files are searched in priority order ``start*``, ``*.pt``, ``end*``. The first match of
        the first non-empty pattern is loaded with ``load_pkl`` and then removed. Returns the
        loaded object, or None when no file is present or loading failed.
        """
        buffer = None
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            pattern = os.path.join(self.nfs_path, file_type)
            files = glob.glob(pattern)
            if len(files) > 0:
                cur_file = files[0]
                break

        if cur_file is not None:
            # NOTE(review): a file that upload() is still writing may be picked up here; a failed
            # load is only logged and the file is removed anyway, so that item is lost.
            try:
                buffer = load_pkl(cur_file)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            remove_path(cur_file)
        return buffer
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@recursion_depth_decorator("move2device_exec")
def move2device_exec(obj, device):
    """Recursively move every tensor inside ``obj`` to ``device``.

    - Lists are rebuilt as lists. Tuples, including tuple subclasses, are rebuilt as plain tuples.
    - Dicts are rebuilt value by value.
    - Tensors are detached and moved when their device type differs from ``device``.
    - Objects whose type name contains ``return_types`` are treated as tuples.
    - ``torch.device`` values are replaced by ``torch.device(device)``.

    Any other object is returned unchanged.
    """
    if isinstance(obj, (tuple, list)):
        moved_items = [move2device_exec(item, device) for item in obj]
        if isinstance(obj, list):
            return moved_items
        return tuple(moved_items)
    if isinstance(obj, dict):
        return {name: move2device_exec(value, device) for name, value in obj.items()}
    if isinstance(obj, torch.Tensor):
        tensor = obj.detach()
        if tensor.device.type != device:
            tensor = tensor.to(device)
        return tensor
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch._C.device):
        return torch.device(device)
    return obj
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def move2target_device(buffer: ApiData, target_device):
    """Return a new ApiData whose tensors are moved to ``target_device``.

    ``args`` (returned as a tuple) and ``kwargs`` are always moved. The recorded ``result``
    is moved only when the target is the CPU, given as ``torch.device('cpu')`` or ``"cpu"``;
    otherwise the original result object is kept. ``name``, ``step`` and ``rank`` are
    carried over unchanged.
    """
    moved_args = move2device_exec(buffer.args, target_device)
    moved_kwargs = move2device_exec(buffer.kwargs, target_device)
    moved_result = move2device_exec(buffer.result, target_device)

    targets_cpu = target_device == torch.device('cpu') or target_device == "cpu"
    final_result = moved_result if targets_cpu else buffer.result
    return ApiData(buffer.name, tuple(moved_args), moved_kwargs, final_result, buffer.step, buffer.rank)
|