addftool 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- addftool/addf_portal.py +4 -0
- addftool/broadcast_folder.py +80 -41
- addftool/sleep.py +218 -0
- addftool/ssh/__init__.py +128 -0
- {addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/METADATA +5 -2
- {addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/RECORD +9 -8
- {addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/WHEEL +1 -1
- {addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/entry_points.txt +0 -0
- {addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/top_level.txt +0 -0
addftool/addf_portal.py
CHANGED
@@ -3,6 +3,7 @@ from addftool.process import add_killer_args, killer_main
 from addftool.sync import add_sync_args, sync_main
 from addftool.deploy import add_deploy_args, deploy_main
 from addftool.broadcast_folder import add_broadcast_folder_args, broadcast_folder_main
+from addftool.sleep import add_sleep_args, sleep_main

 from addftool.blob import add_blob_args, blob_main

@@ -16,6 +17,7 @@ def get_args():
     add_deploy_args(subparsers)
     add_broadcast_folder_args(subparsers)
     add_blob_args(subparsers)
+    add_sleep_args(subparsers)

     return parser.parse_args()

@@ -32,6 +34,8 @@ def main():
         broadcast_folder_main(args)
     elif args.command == "blob":
         blob_main(args)
+    elif args.command == "sleep":
+        sleep_main(args)
     else:
         print("Unknown command: ", args.command)

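The portal wires every subcommand the same way: an add_*_args helper registers a subparser and main() dispatches on args.command, which is exactly how the new sleep command is hooked in above. A minimal self-contained sketch of that pattern, using a hypothetical demo command in place of the real modules:

    import argparse

    def add_demo_args(subparsers):
        # hypothetical stand-in for add_sleep_args(subparsers)
        demo = subparsers.add_parser("demo", help="example subcommand")
        demo.add_argument("--count", type=int, default=1)

    def demo_main(args):
        print(f"demo ran {args.count} time(s)")

    def main():
        parser = argparse.ArgumentParser()
        subparsers = parser.add_subparsers(dest="command")
        add_demo_args(subparsers)
        args = parser.parse_args()
        if args.command == "demo":
            demo_main(args)
        else:
            print("Unknown command: ", args.command)

    if __name__ == "__main__":
        main()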
addftool/broadcast_folder.py
CHANGED
@@ -3,13 +3,20 @@ import time
 import fnmatch
 import subprocess
 import hashlib
+import sys
+import warnings
+try:
+    from addftool.ssh import get_client, get_ssh_config, handle_hosts_outputs
+    from pssh.clients import ParallelSSHClient
+except ImportError:
+    pass
 from concurrent.futures import ThreadPoolExecutor
-
-from fabric import Connection, ThreadingGroup
+import gevent

 try:
     import torch
     import torch.distributed as dist
+    import torch.multiprocessing as mp
     from torch.distributed import init_process_group, destroy_process_group
     _torch_is_available = True
 except ImportError:

@@ -29,9 +36,8 @@ def add_args(parser):

     parser.add_argument("--md5_verify", action='store_true', default=False,
                         help="whether to verify the md5 of the file after broadcast, default is False.")
-    parser.add_argument("--port", help="the port for
-    parser.add_argument("--
-                        help="the alias of torchrun, default is torchrun. If you use torchrun, please set it to torchrun.")
+    parser.add_argument("--port", help="the port for torch, default is 29501", type=int, default=29501)
+    parser.add_argument("--python_alias", type=str, default="python")
     parser.add_argument("--transfer_ranks_per_node", type=int, default=8,
                         help="the number of ranks per node to transfer the files, default is 8.")

@@ -55,28 +61,13 @@ def add_args(parser):
                         help="the blob url to download from, default is empty. " \
                              "Only node-0 will download the files from the blob url, " \
                              "If empty, will transfer the files from the node-0's local folder.")
+
+    parser.add_argument("--worker_args", type=str, default="")

     # distributed downloader from blob
     parser.add_argument("folder", help="the folder need to broadcast", type=str)


-class ConnectionWithCommand(Connection):
-    def __init__(self, host, temp_config_dir, puts, command):
-        super().__init__(host)
-        self.command = command
-        self.puts = puts
-        self.temp_config_dir = temp_config_dir
-
-    def run(self, command, **kwargs):
-        super().run(f"mkdir -p {self.temp_config_dir}", **kwargs)
-        for src, dest in self.puts:
-            self.put(src, remote=dest)
-        print(f"Running command on {self.original_host}: {self.command}")
-        super().run(self.command, **kwargs)
-        if command:
-            super().run(command, **kwargs)
-
-
 def get_ip_via_ssh(hostname):
     if hostname == "localhost":
         return "127.0.0.1"

@@ -150,20 +141,40 @@ def broadcast_folder_main(args):
                 host_list.append(line)

     print(f"Find {len(host_list)} hosts in hostfile: {args.hostfile}")
-    connection_list = []

     remote_temp_config_dir = "/tmp/broadcast_temp_config_dir"
-
     master_addr = get_ip_via_ssh(host_list[0])
+
+    user_ssh_config, systen_ssh_config = get_ssh_config()
+    not_successed_hosts = host_list[:]
+    for try_count in range(3):
+        client = get_client(not_successed_hosts, user_ssh_configs=user_ssh_config, system_ssh_configs=systen_ssh_config)
+        mkdir_cmds = client.run_command(f"mkdir -p {remote_temp_config_dir}", stop_on_errors=False)
+        client.join(mkdir_cmds, timeout=10)
+        assert isinstance(client, ParallelSSHClient), "Failed to create ParallelSSHClient"
+        cmds = client.scp_send(__file__, os.path.join(remote_temp_config_dir, "broadcast_folder.py"))
+        gevent.joinall(cmds, raise_error=False, timeout=10)
+        resend_hosts = []
+        for host, cmd in zip(not_successed_hosts, cmds):
+            if cmd.exception:
+                print(f"Failed to copy file to {host}: {cmd.exception}")
+                resend_hosts.append(host)
+
+        if len(resend_hosts) == 0:
+            break
+        not_successed_hosts = resend_hosts
+
+    if len(resend_hosts) > 0:
+        print(f"Failed to copy file to {len(resend_hosts)} hosts: {resend_hosts}")
+        sys.exit(1)

+    host_commands = []
     for i, host in enumerate(host_list):
-
-        put_commands.append((__file__, os.path.join(remote_temp_config_dir, "broadcast.py")))
-        commnads = "NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
+        commnads = "PYTHONUNBUFFERED=1 NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
         if os.environ.get("SAS_TOKEN") is not None and i == 0:
            commnads += f"SAS_TOKEN=\"{os.environ['SAS_TOKEN']}\" "
-        commnads += f"{args.
-        commnads += f" {
+        commnads += f"{args.python_alias} {remote_temp_config_dir}/broadcast_folder.py {args.folder} --tool {args.tool} --transfer_ranks_per_node {args.transfer_ranks_per_node} "
+        commnads += f" --worker_args {master_addr}_{args.port}_{len(host_list)}_{args.transfer_ranks_per_node}_{i} "
         if args.contain_md5_files:
             commnads += " --contain_md5_files"
         if args.include_string:

@@ -174,11 +185,20 @@ def broadcast_folder_main(args):
         commnads += f" --from_blob_url {args.from_blob_url}"
         if args.md5_verify:
             commnads += " --md5_verify"
-
-
+        host_commands.append(commnads)
+        print(f"Run command on {host}: {commnads}")
+
+    client = get_client(host_list, user_ssh_configs=user_ssh_config, system_ssh_configs=systen_ssh_config)
+    if True:
+        cmds = client.run_command("%s", host_args=host_commands, stop_on_errors=False)
+        handle_hosts_outputs(cmds)
+        for host, cmd in zip(host_list, cmds):
+            if cmd.exception or cmd.exit_code != 0:
+                print(f"Failed to run command on {host}: {cmd.exception}, EXIT CODE: {cmd.exit_code}")
+                sys.exit(1)
+            print(f"Command on {host} finished with exit code {cmd.exit_code}")

-
-    group.run('echo "Hello"', hide=False)
+    print(f"All nodes finished broadcasting files")


 def download_files_from_blob(queue, blob_url, sas_token, folder, download_files, node_rank):

@@ -259,16 +279,14 @@ def broadcast_file_from_rank(rank, file_path, from_rank, device, file_size, max_
         raise ValueError(f"MD5 verification failed for {file_path}: {file_md5} != {src_md5_str}")


-def broadcast_folder_worker(args):
+def broadcast_folder_worker(local_rank, node_rank, world_size, master_addr, master_port, args):
     assert args.tool in ["torch_nccl"], f"tool {args.tool} is not supported"
     if not _torch_is_available:
         raise ImportError("Torch is not available. Please install torch to use this feature.")
     start_time = time.time()

-
-    global_rank =
-    local_rank = int(os.environ['LOCAL_RANK'])
-    world_size = int(os.environ['WORLD_SIZE'])
+    global_rank = local_rank + node_rank * args.transfer_ranks_per_node
+    init_process_group(backend='nccl', init_method=f"tcp://{master_addr}:{master_port}", rank=global_rank, world_size=world_size)
     num_nodes = world_size // args.transfer_ranks_per_node
     worker_rank = local_rank

@@ -291,11 +309,15 @@ def broadcast_folder_worker(args):
     if global_rank == 0:
         print(f"Init {len(workers_groups)} worker groups")

-
-
+    if global_rank == 0:
+        print(f"rank {global_rank} start broadcast worker, args = {args}, nccl init time: {time.time() - start_time:.2f}s")
+    else:
+        print(f"rank {global_rank} nccl init time: {time.time() - start_time:.2f}s")
+
     file_size_dict = {}

     if global_rank == 0:
+        warnings.filterwarnings("ignore", category=UserWarning, message="The given buffer is not writable, and PyTorch does not support non-writable tensors.")
         # Parse include and exclude patterns
         include_patterns = [p.strip() for p in args.include_string.split(";") if p.strip()]
         exclude_patterns = [p.strip() for p in args.exclude_string.split(";") if p.strip()]

@@ -397,6 +419,23 @@
     print(f"Rank {global_rank} finished broadcasting all files, time taken: {time.time() - start_time:.2f}s")


+def broadcast_node_main(args):
+    if not _torch_is_available:
+        raise ImportError("Torch is not available. Please install torch to use this feature.")
+    parts = args.worker_args.split("_")
+    master_addr = parts[0]
+    master_port = int(parts[1])
+    num_nodes = int(parts[2])
+    num_ranks_per_node = int(parts[3])
+    node_rank = int(parts[4])
+    world_size = num_nodes * num_ranks_per_node
+
+    mp.spawn(
+        broadcast_folder_worker, nprocs=num_ranks_per_node, join=True,
+        args=(node_rank, world_size, master_addr, master_port, args),
+    )
+
+
 if __name__ == "__main__":
     import argparse

@@ -406,4 +445,4 @@ if __name__ == "__main__":
     if args.hostfile:
         broadcast_folder_main(args)
     else:
-
+        broadcast_node_main(args)
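The coordinator and the per-node entry point now share a small rendezvous string passed through --worker_args: master address, port, node count, ranks per node, and node rank joined with underscores, which broadcast_node_main splits back apart before calling mp.spawn. A minimal sketch of that packing and unpacking convention (the function names here are illustrative, not part of the package):

    def pack_worker_args(master_addr, port, num_nodes, ranks_per_node, node_rank):
        # mirrors the value built in broadcast_folder_main
        return f"{master_addr}_{port}_{num_nodes}_{ranks_per_node}_{node_rank}"

    def unpack_worker_args(worker_args):
        # mirrors the split("_") in broadcast_node_main
        parts = worker_args.split("_")
        return parts[0], int(parts[1]), int(parts[2]), int(parts[3]), int(parts[4])

    # pack_worker_args("10.0.0.1", 29501, 4, 8, 2) -> "10.0.0.1_29501_4_8_2"
    print(unpack_worker_args("10.0.0.1_29501_4_8_2"))

Splitting on "_" works because get_ip_via_ssh hands the coordinator a plain IPv4 address, which never contains underscores.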
addftool/sleep.py
ADDED
@@ -0,0 +1,218 @@
+import time
+import subprocess
+import sys
+import multiprocessing as mp
+
+try:
+    import torch
+except ImportError:
+    print("PyTorch is not installed. Please install it to run this script.")
+    sys.exit(1)
+
+
+def get_gpu_stats(device_id):
+    """Get the utilization and memory usage of the given GPU."""
+    try:
+        cmd = f"nvidia-smi --id={device_id} --query-gpu=utilization.gpu,memory.used --format=csv,noheader,nounits"
+        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            print(f"Error running nvidia-smi for GPU {device_id}")
+            return None, None
+
+        # Parse the output
+        output = result.stdout.strip()
+        if output:
+            parts = output.split(',')
+            if len(parts) == 2:
+                gpu_util = int(parts[0])  # GPU utilization in percent
+                memory_used = int(parts[1])  # GPU memory used in MB
+                return gpu_util, memory_used
+
+        return None, None
+    except Exception as e:
+        print(f"Error getting GPU stats for device {device_id}: {e}")
+        return None, None
+
+def check_gpu_occupied(device_id, util_threshold=20, memory_threshold=2048):
+    """Check whether the GPU is occupied by other processes.
+
+    Args:
+        device_id: GPU device ID
+        util_threshold: GPU utilization threshold (default 20%)
+        memory_threshold: GPU memory usage threshold (default 2048 MB = 2 GB)
+
+    Returns:
+        bool: True if the GPU is occupied, False if it is idle
+    """
+    gpu_util, memory_used = get_gpu_stats(device_id)
+
+    if gpu_util is None or memory_used is None:
+        # Be conservative if the stats could not be retrieved
+        return True
+
+    # Decide whether the GPU is occupied
+    is_occupied = gpu_util > util_threshold or (memory_threshold > 0 and memory_used > memory_threshold)
+
+    if is_occupied:
+        print(f"GPU {device_id}: Util={gpu_util}%, Memory={memory_used}MB - Occupied")
+
+    return is_occupied
+
+def check_all_gpus(num_gpus, util_threshold=20, memory_threshold=-1):
+    """Check whether any GPU is occupied."""
+    for device_id in range(num_gpus):
+        if check_gpu_occupied(device_id, util_threshold, memory_threshold):
+            return True, device_id
+    return False, -1
+
+def get_all_gpu_status(num_gpus):
+    """Print the status of every GPU."""
+    print("\nGPU Status:")
+    print("-" * 50)
+    for device_id in range(num_gpus):
+        gpu_util, memory_used = get_gpu_stats(device_id)
+        if gpu_util is not None and memory_used is not None:
+            status = "Available" if (gpu_util <= 20 and memory_used <= 2048) else "Occupied"
+            print(f"GPU {device_id}: Util={gpu_util:3d}%, Memory={memory_used:5d}MB - {status}")
+        else:
+            print(f"GPU {device_id}: Unable to get stats")
+    print("-" * 50)
+
+def matrix_multiply_worker(matrix_size=8192, time_duration=4.0, sleep_duration=1.0, util_threshold=20, memory_threshold=-1):
+    # Get the number of GPUs
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        print("No GPUs available!")
+        return
+
+    matrices = {}
+    # print(f"Creating {matrix_size}x{matrix_size} matrices on all GPUs...")
+    for device_id in range(num_gpus):
+        device = torch.device(f'cuda:{device_id}')
+        matrices[device_id] = {
+            'a': torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32),
+            'b': torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32)
+        }
+
+    # Main loop
+    while True:
+        try:
+            # Check whether any GPU is occupied
+            has_occupied_gpu, occupied_gpu = check_all_gpus(num_gpus, util_threshold, memory_threshold)
+            if has_occupied_gpu:
+                break
+
+            start_time = time.time()
+            perform_count = 0
+            while True:
+                # Run matrix multiplications on all GPUs at the same time
+                results = {}
+                for device_id in range(num_gpus):
+                    results[device_id] = torch.matmul(matrices[device_id]['a'], matrices[device_id]['b'])
+
+                perform_count += 1
+
+                if perform_count % 10 == 0:
+                    for device_id in range(num_gpus):
+                        torch.cuda.synchronize(device_id)
+
+                torch.cuda.synchronize()  # make sure all GPU operations have finished
+                elapsed_time = time.time() - start_time
+                if elapsed_time > time_duration:
+                    break
+
+            # Free memory
+
+            time.sleep(sleep_duration)
+
+        except KeyboardInterrupt:
+            print("\nKeyboardInterrupt received, stopping...")
+            stop_flag = True
+            exit(0)
+        except Exception as e:
+            print(f"\nError occurred: {e}")
+            # Try to free GPU memory
+            try:
+                for device_id in range(num_gpus):
+                    torch.cuda.set_device(device_id)
+                    torch.cuda.empty_cache()
+            except:
+                pass
+            time.sleep(5)
+
+def sleep_main(args):
+    # Set the multiprocessing start method
+    mp.set_start_method('spawn', force=True)
+
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        print("No GPUs available!")
+        exit(1)
+
+    # Show the initial GPU status
+    get_all_gpu_status(num_gpus)
+
+    current_process = None
+
+    # Main loop
+    while True:
+        try:
+            # Check whether any GPU is occupied
+            has_occupied_gpu, occupied_gpu = check_all_gpus(num_gpus, util_threshold=args.util_threshold, memory_threshold=args.memory_threshold)
+
+            if has_occupied_gpu:
+                # Hold for 60 seconds
+                print("Holding for 60 seconds...")
+                time.sleep(60)
+
+            else:
+                # GPUs are idle, start the matrix multiplication process
+                current_process = mp.Process(
+                    target=matrix_multiply_worker,
+                    args=(args.matrix_size, args.time_duration, args.sleep_duration, args.util_threshold, args.memory_threshold),
+                )
+                current_process.start()
+                current_process.join()
+
+        except KeyboardInterrupt:
+            print("\nKeyboardInterrupt received, stopping...")
+            stop_flag = True
+            break
+        except Exception as e:
+            print(f"\nError occurred: {e}")
+            time.sleep(5)
+
+    print("\nProgram stopped")
+
+
+def add_sleep_args(subparsers):
+    sleep_parser = subparsers.add_parser('sleep', help='Sleep for a while and check GPU status')
+    add_args(sleep_parser)
+
+
+def add_args(parser):
+    parser.add_argument('--matrix_size', type=int, default=8192, help='Size of the matrices to multiply')
+    parser.add_argument('--time_duration', type=float, default=4.0, help='Duration to perform matrix multiplication')
+    parser.add_argument('--sleep_duration', type=float, default=1.0, help='Duration to sleep between checks')
+
+    parser.add_argument('--util_threshold', type=int, default=20, help='GPU utilization threshold to consider it occupied')
+    parser.add_argument('--memory_threshold', type=int, default=-1, help='Memory usage threshold (in GB) to consider it occupied, set to -1 to disable')
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='Sleep and check GPU status')
+    add_args(parser)
+    args = parser.parse_args()
+    while True:
+        try:
+            sleep_main(args)
+        except KeyboardInterrupt:
+            print("\nKeyboardInterrupt received, exiting...")
+            sys.exit(0)
+        except Exception as e:
+            print(f"\nUnexpected error: {e}")
+            print("Restarting the program in 5 seconds...")
+            time.sleep(5)
+            continue
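The occupancy rule in check_gpu_occupied treats a GPU as busy when its utilization exceeds util_threshold, or when memory_threshold is enabled (greater than 0) and used memory exceeds it; a failed nvidia-smi query is treated as busy. Distilled into a small pure function with a few illustrative inputs:

    def is_occupied(gpu_util, memory_used, util_threshold=20, memory_threshold=-1):
        # None means the nvidia-smi query failed: be conservative and report busy
        if gpu_util is None or memory_used is None:
            return True
        return gpu_util > util_threshold or (memory_threshold > 0 and memory_used > memory_threshold)

    print(is_occupied(5, 300))                          # False: idle, memory check disabled
    print(is_occupied(45, 300))                         # True: above the 20% utilization default
    print(is_occupied(5, 3000, memory_threshold=2048))  # True: above 2048 MB
    print(is_occupied(None, None))                      # True: query failed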
addftool/ssh/__init__.py
CHANGED
@@ -0,0 +1,128 @@
+import re
+import os
+from pssh.clients import ParallelSSHClient
+from pssh.config import HostConfig
+import gevent
+
+
+def get_host_config(hostname, configs):
+    """Get configuration for a specific host from all configs."""
+    host_config = {}
+
+    # Check for exact hostname match
+    if hostname in configs:
+        host_config.update(configs[hostname])
+
+    # Check for wildcard matches
+    for pattern, config in configs.items():
+        if '*' in pattern or '?' in pattern:
+            # Convert SSH glob pattern to regex pattern
+            regex_pattern = pattern.replace('.', '\\.').replace('*', '.*').replace('?', '.')
+            if re.match(f"^{regex_pattern}$", hostname):
+                host_config.update(config)
+
+    return host_config
+
+
+def parse_ssh_config_file(file_path):
+    """Parse SSH config file into a dictionary of host configurations."""
+    host_configs = {}
+    current_host = None
+
+    if not os.path.exists(file_path):
+        return host_configs
+
+    try:
+        with open(file_path, 'r') as f:
+            for line in f:
+                line = line.strip()
+                if not line or line.startswith('#'):
+                    continue
+
+                if line.lower().startswith('host ') and not line.lower().startswith('host *'):
+                    hosts = line.split()[1:]
+                    for host in hosts:
+                        current_host = host
+                        if current_host not in host_configs:
+                            host_configs[current_host] = {}
+                elif current_host and ' ' in line:
+                    key, value = line.split(None, 1)
+                    host_configs[current_host][key.lower()] = value
+    except Exception as e:
+        print(f"Error reading {file_path}: {e}")
+
+    return host_configs
+
+
+def get_ssh_config():
+    user_config_path = os.path.expanduser("~/.ssh/config")
+    system_config_path = "/etc/ssh/ssh_config"
+
+    user_configs = parse_ssh_config_file(user_config_path)
+    system_configs = parse_ssh_config_file(system_config_path)
+    return user_configs, system_configs
+
+
+def get_client(hosts, user_ssh_configs=None, system_ssh_configs=None):
+    if user_ssh_configs is None or system_ssh_configs is None:
+        user_ssh_configs, system_ssh_configs = get_ssh_config()
+
+    to_connect = []
+    for hostname in hosts:
+        host_ssh_config = get_host_config(hostname, user_ssh_configs)
+        if not host_ssh_config:
+            host_ssh_config = get_host_config(hostname, system_ssh_configs)
+
+        if host_ssh_config:
+            host_config = HostConfig(
+                user=host_ssh_config.get('user', os.getenv('USER', 'root')),
+                port=int(host_ssh_config.get('port', 22)),
+                private_key=host_ssh_config.get('identityfile', None),
+            )
+            host_name = host_ssh_config.get('hostname', hostname)
+            to_connect.append((host_name, host_config))
+        else:
+            print(f"No config found for host {hostname}")
+
+    # Create a ParallelSSHClient with the list of hosts
+    client = ParallelSSHClient([host[0] for host in to_connect], host_config=[host[1] for host in to_connect])
+    return client
+
+
+def handle_stream(host, stream_in, stream_name, print_call=None):
+    if print_call is None:
+        print_call = print
+    try:
+        if stream_in:
+            for line in stream_in:
+                prefix = " ERROR" if stream_name == "stderr" else ""
+                print_call(f"[{host}]{prefix}: {line}")
+                gevent.sleep(0)  # yield control to other greenlets
+    except Exception as e:
+        print(f"[{host}] {stream_name} Exception: {e}")
+
+
+def _stdout_log(line):
+    """Log the output line."""
+    print(line)
+
+
+def _stderr_log(line):
+    """Log the error line."""
+    print("ERROR: " + line)
+
+
+def handle_hosts_outputs(hosts_outputs, out_log=None, err_log=None):
+    """Handle the outputs from the SSH command execution."""
+    if out_log is None:
+        out_log = _stdout_log
+    if err_log is None:
+        err_log = _stderr_log
+    jobs = []
+    for output in hosts_outputs:
+        host_name = output.host
+        if output:
+            jobs.append(gevent.spawn(handle_stream, host_name, output.stdout, "stdout", out_log))
+            jobs.append(gevent.spawn(handle_stream, host_name, output.stderr, "stderr", err_log))
+
+    gevent.joinall(jobs, raise_error=False)
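Together with the broadcast_folder changes above, the intended flow of this helper module is: read the SSH settings once with get_ssh_config, build a ParallelSSHClient over the hosts with get_client, run a command, and stream the per-host output through handle_hosts_outputs. A minimal usage sketch under those assumptions (the hostnames are placeholders that must resolve via ~/.ssh/config or /etc/ssh/ssh_config):

    from addftool.ssh import get_client, get_ssh_config, handle_hosts_outputs

    hosts = ["node-0", "node-1"]  # placeholder hostnames
    user_cfg, system_cfg = get_ssh_config()
    client = get_client(hosts, user_ssh_configs=user_cfg, system_ssh_configs=system_cfg)

    # run the same command on every host; stop_on_errors=False keeps going past failing hosts
    outputs = client.run_command("hostname", stop_on_errors=False)
    handle_hosts_outputs(outputs)  # prints "[host]: line" for stdout, "[host] ERROR: line" for stderr

    for host, out in zip(hosts, outputs):
        if out.exception or out.exit_code != 0:
            print(f"Command failed on {host}: {out.exception}, exit code {out.exit_code}")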
{addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/METADATA
CHANGED
@@ -1,8 +1,11 @@
-Metadata-Version: 2.
+Metadata-Version: 2.1
 Name: addftool
-Version: 0.2.
+Version: 0.2.8
 Requires-Dist: cryptography
 Requires-Dist: requests
 Requires-Dist: PyYAML
 Requires-Dist: psutil
 Requires-Dist: fabric
+Requires-Dist: gevent
+Requires-Dist: parallel-ssh
+
{addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/RECORD
CHANGED
@@ -1,7 +1,8 @@
 addftool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-addftool/addf_portal.py,sha256=
+addftool/addf_portal.py,sha256=vc8opPzValNFPwJne5C5LbZvgcJ0eMBJSWDSiM23OPM,1274
 addftool/blob.py,sha256=y1HZaDBUNeXicVytvwpRXwufvvrgxR33ruBlYpxnSa4,9453
-addftool/broadcast_folder.py,sha256=
+addftool/broadcast_folder.py,sha256=GQBuSL8Ch537V_fSBHesWyqT3KRYry68pbYOKy2bDj4,19619
+addftool/sleep.py,sha256=FA1fTUI47eQq-9nBtXElkS7SZMunP_5tLiIBuFNSM6w,7823
 addftool/sync.py,sha256=ZpYxbM8uiPFrV7ODmOaM7asVPCWaxBixA-arVc-1kfs,14045
 addftool/tool.py,sha256=FmxRY3-pP0_Z0zCUAngjmEMmPUruMftg_iUlB1t2TnQ,2001
 addftool/util.py,sha256=zlNLu8Be8cGIpNRqBw8_0q7nFxWlsJ9cToN62ohjdXE,2335
@@ -11,9 +12,9 @@ addftool/deploy/ssh_server.py,sha256=7glpJJNROskpqPkeYrTc2MbVzRendUZLv-ZgPs6HCq8
 addftool/deploy/vscode_server.py,sha256=tLtSvlcK2fEOaw6udWt8dNELVhwv9F59hF5DJJ-1Nak,2666
 addftool/process/__init__.py,sha256=Dze8OrcyjQlAbPrjE_h8bMi8W4b3OJyZOjTucPrkJvM,3721
 addftool/process/utils.py,sha256=JldxnwanLJOgxaPgmCJh7SeBRaaj5rFxWWxh1hpsvbA,2609
-addftool/ssh/__init__.py,sha256=
-addftool-0.2.
-addftool-0.2.
-addftool-0.2.
-addftool-0.2.
-addftool-0.2.
+addftool/ssh/__init__.py,sha256=h5_rCO0A6q2Yw9vFguQZZp_ApAJsT1dcnKnbKKZ0cDM,4409
+addftool-0.2.8.dist-info/METADATA,sha256=rxu5Oy4lH7lQF99Z8gzz5QuoGxnZ739h0OBNhr_0NA0,221
+addftool-0.2.8.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+addftool-0.2.8.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
+addftool-0.2.8.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
+addftool-0.2.8.dist-info/RECORD,,
{addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/entry_points.txt
File without changes
{addftool-0.2.6.dist-info → addftool-0.2.8.dist-info}/top_level.txt
File without changes