ladyrick 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ladyrick/allgather.py +106 -0
- ladyrick/cli/multi_ssh.py +38 -30
- ladyrick/cli/psf.py +2 -2
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/METADATA +1 -1
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/RECORD +9 -8
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/entry_points.txt +1 -0
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/WHEEL +0 -0
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {ladyrick-0.5.0.dist-info → ladyrick-0.5.2.dist-info}/top_level.txt +0 -0
ladyrick/allgather.py
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
import os
|
2
|
+
import socket
|
3
|
+
import struct
|
4
|
+
import sys
|
5
|
+
import time
|
6
|
+
|
7
|
+
|
8
|
+
class AllGather:
    """Minimal TCP all-gather between ``WORLD_SIZE`` processes.

    Configuration is read from the ``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE`` and ``RANK`` environment variables. Rank 0 listens on
    ``MASTER_PORT``, collects one payload from every other rank, then sends
    the complete list back to each of them.

    Wire format (all integers unsigned 32-bit, network byte order):

    * worker -> master: ``rank``, ``len(data)``, then ``data``
    * master -> worker: ``WORLD_SIZE`` lengths, then the payloads back to back

    Every rank must run the same version of this class, since both ends of
    the protocol live here.
    """

    # Header a worker sends before its payload: (rank, payload length).
    _HEADER = struct.Struct("!II")

    def __init__(self):
        """Read the rendezvous settings from the environment.

        Raises:
            RuntimeError: if any of the four required variables is missing.
        """
        if {"MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} - set(os.environ):
            raise RuntimeError("MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK env is required")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = int(os.environ["MASTER_PORT"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        assert 0 <= self.rank < self.world_size

    @staticmethod
    def _recv_exact(sock, size: int) -> bytes:
        """Read exactly ``size`` bytes from ``sock``.

        A single ``recv`` call may return fewer bytes than requested, so keep
        reading until the full amount has arrived.

        Raises:
            ConnectionError: if the peer closes the connection early.
        """
        buf = bytearray()
        while len(buf) < size:
            chunk = sock.recv(size - len(buf))
            if not chunk:
                raise ConnectionError(f"connection closed after {len(buf)} of {size} bytes")
            buf += chunk
        return bytes(buf)

    def allgather(self, data: bytes):
        """Exchange ``data`` with all other ranks.

        Returns:
            A list of ``WORLD_SIZE`` byte strings where index ``i`` holds the
            payload contributed by rank ``i``.
        """
        if self.world_size == 1:
            assert self.rank == 0
            return [data]

        if self.rank == 0:
            return self._run_master(data)
        return self._run_worker(data)

    def _run_master(self, data: bytes):
        """Rank 0: accept every worker, collect the payloads, broadcast them all.

        Exits the process with status 1 if the listening port cannot be bound.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            server_socket.bind(("0.0.0.0", self.master_port))
            server_socket.listen(self.world_size - 1)
        except OSError as e:
            server_socket.close()
            # stderr, so the diagnostic never mixes with gathered output on stdout.
            print(f"Master failed to bind port: {e}", file=sys.stderr)
            sys.exit(1)

        conns: list = []
        try:
            # collect: slot i is filled by the worker that announces rank i,
            # regardless of the order in which workers happen to connect.
            all_data: list = [None] * self.world_size
            all_data[0] = data
            for _ in range(1, self.world_size):
                conn, _addr = server_socket.accept()
                conns.append(conn)
                header = self._recv_exact(conn, self._HEADER.size)
                worker_rank, cur_data_len = self._HEADER.unpack(header)
                if not 0 < worker_rank < self.world_size or all_data[worker_rank] is not None:
                    raise RuntimeError(f"unexpected or duplicate worker rank: {worker_rank}")
                all_data[worker_rank] = self._recv_exact(conn, cur_data_len)

            # broadcast
            data_len_pack = struct.pack(f"!{len(all_data)}I", *map(len, all_data))
            for conn in conns:
                conn.sendall(data_len_pack)
                for d in all_data:
                    conn.sendall(d)
        finally:
            # Close every accepted connection even when collecting or
            # broadcasting failed part-way through.
            for conn in conns:
                conn.close()
            server_socket.close()
        return all_data

    def _run_worker(self, data: bytes):
        """Non-zero rank: send our payload to the master, receive everyone's."""
        while True:
            # A socket whose connect() failed is not reliably reusable, so
            # every attempt gets a fresh one.
            worker_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                worker_socket.connect((self.master_addr, self.master_port))
            except (ConnectionRefusedError, socket.timeout):
                # Master is not listening yet; back off briefly and retry.
                worker_socket.close()
                time.sleep(0.1)
            except BaseException:
                worker_socket.close()
                raise
            else:
                break

        try:
            worker_socket.sendall(self._HEADER.pack(self.rank, len(data)))
            worker_socket.sendall(data)

            lens_raw = self._recv_exact(worker_socket, 4 * self.world_size)
            data_lens = struct.unpack(f"!{self.world_size}I", lens_raw)
            return [self._recv_exact(worker_socket, n) for n in data_lens]
        finally:
            worker_socket.close()
|
85
|
+
|
86
|
+
|
87
|
+
def main():
    """CLI entry point: all-gather one message per process and print the results.

    The positional ``msg`` is UTF-8 encoded and exchanged through
    ``AllGather``. The gathered messages are printed joined by newlines, or by
    NUL characters when ``-0`` is given.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(prog="allgather")
    cli.add_argument("msg", type=str, help="msg to allgather")
    cli.add_argument("-0", action="store_true", help="separate by \\0", dest="zero")
    opts = cli.parse_args()

    payloads = AllGather().allgather(opts.msg.encode())
    separator = "\0" if opts.zero else "\n"
    print(separator.join(item.decode() for item in payloads), flush=True)


if __name__ == "__main__":
    main()
|
ladyrick/cli/multi_ssh.py
CHANGED
@@ -1,17 +1,9 @@
|
|
1
|
-
import argparse
|
2
|
-
import dataclasses
|
3
|
-
import itertools
|
4
1
|
import json
|
5
2
|
import os
|
6
|
-
import pathlib
|
7
|
-
import random
|
8
3
|
import select
|
9
|
-
import shlex
|
10
4
|
import signal
|
11
5
|
import subprocess
|
12
6
|
import sys
|
13
|
-
import time
|
14
|
-
import uuid
|
15
7
|
|
16
8
|
|
17
9
|
def log(msg):
|
@@ -37,19 +29,18 @@ def remote_head():
|
|
37
29
|
extra_envs = json.loads(sys.argv[2])
|
38
30
|
os.environ.update(extra_envs)
|
39
31
|
|
40
|
-
|
41
|
-
parser.add_argument("cmd", type=str, nargs="+", help="cmd")
|
42
|
-
args = parser.parse_args(sys.argv[3:])
|
32
|
+
cmd = sys.argv[3:]
|
43
33
|
try:
|
44
34
|
from setproctitle import setproctitle
|
45
35
|
|
46
|
-
setproctitle(" ".join([REMOTE_HEAD_PROG_NAME] +
|
36
|
+
setproctitle(" ".join([REMOTE_HEAD_PROG_NAME] + cmd))
|
47
37
|
except ImportError:
|
48
38
|
pass
|
49
39
|
|
50
40
|
# start child process
|
51
41
|
child = subprocess.Popen(
|
52
|
-
|
42
|
+
cmd,
|
43
|
+
stdin=subprocess.PIPE,
|
53
44
|
stdout=sys.stdout,
|
54
45
|
stderr=sys.stderr,
|
55
46
|
start_new_session=True,
|
@@ -97,6 +88,15 @@ if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == REMOTE_HEAD_P
|
|
97
88
|
|
98
89
|
|
99
90
|
# ----- remote_head end ----- #
|
91
|
+
if True:
|
92
|
+
import argparse
|
93
|
+
import dataclasses
|
94
|
+
import itertools
|
95
|
+
import pathlib
|
96
|
+
import random
|
97
|
+
import shlex
|
98
|
+
import time
|
99
|
+
import uuid
|
100
100
|
|
101
101
|
|
102
102
|
@dataclasses.dataclass
|
@@ -158,18 +158,20 @@ class RemoteExecutor:
|
|
158
158
|
def set_envs(cls, executors: list["RemoteExecutor"]):
|
159
159
|
assert executors
|
160
160
|
envs = {}
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
161
|
+
if len(executors) > 1:
|
162
|
+
cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
|
163
|
+
master_ips = subprocess.check_output(cmd).decode().split()
|
164
|
+
priority = {"172": 0, "192": 1, "10": 2}
|
165
|
+
master_addr, cur_p = None, -1
|
166
|
+
for ip in master_ips:
|
167
|
+
prefix = ip.split(".", 1)[0]
|
168
|
+
p = priority.get(prefix, 3)
|
169
|
+
if p > cur_p:
|
170
|
+
master_addr, cur_p = ip, p
|
171
|
+
assert master_addr is not None
|
172
|
+
envs["MASTER_ADDR"] = master_addr
|
173
|
+
else:
|
174
|
+
envs["MASTER_ADDR"] = "127.0.0.1"
|
173
175
|
envs["MASTER_PORT"] = str(random.randint(20000, 40000))
|
174
176
|
envs["WORLD_SIZE"] = str(len(executors))
|
175
177
|
|
@@ -213,10 +215,6 @@ def signal_repeat_checker(sig_to_check, count, duration):
|
|
213
215
|
return checker
|
214
216
|
|
215
217
|
|
216
|
-
def get_common_envs():
|
217
|
-
pass
|
218
|
-
|
219
|
-
|
220
218
|
def main():
|
221
219
|
parser = argparse.ArgumentParser(prog="multi-ssh", add_help=False)
|
222
220
|
parser.add_argument("-h", type=str, action="append", help="hosts to connect. order is 1")
|
@@ -225,13 +223,18 @@ def main():
|
|
225
223
|
parser.add_argument("-l", type=str, help="ssh login User")
|
226
224
|
parser.add_argument("-o", type=str, action="append", help="ssh options")
|
227
225
|
parser.add_argument("-F", type=str, help="ssh config file")
|
228
|
-
parser.add_argument("cmd", type=str, nargs="+", help="cmd")
|
229
226
|
parser.add_argument("--hosts-config", type=str, action="append", help="hosts config string. order is 2")
|
230
227
|
parser.add_argument("--hosts-config-file", type=str, action="append", help="hosts config file. order is 3")
|
231
228
|
parser.add_argument("--help", action="help", default=argparse.SUPPRESS, help="show this help message and exit")
|
229
|
+
parser.add_argument("cmd", type=str, nargs=argparse.REMAINDER, help="cmd")
|
232
230
|
|
233
231
|
args = parser.parse_args()
|
234
232
|
|
233
|
+
if not args.cmd:
|
234
|
+
print("cmd is required\n")
|
235
|
+
parser.print_help()
|
236
|
+
sys.exit(1)
|
237
|
+
|
235
238
|
hosts = [
|
236
239
|
Host(hn, args.F, args.l, args.p, args.i, args.o)
|
237
240
|
for hn in itertools.chain.from_iterable(h.split(",") for h in args.h or [])
|
@@ -255,6 +258,11 @@ def main():
|
|
255
258
|
)
|
256
259
|
)
|
257
260
|
|
261
|
+
if not hosts:
|
262
|
+
print("hosts is required. specify hosts by -h, --hosts-config or --hosts-config-file\n")
|
263
|
+
parser.print_help()
|
264
|
+
sys.exit(1)
|
265
|
+
|
258
266
|
executors = [RemoteExecutor(host, args.cmd) for host in hosts]
|
259
267
|
|
260
268
|
RemoteExecutor.set_envs(executors)
|
ladyrick/cli/psf.py
CHANGED
@@ -5,7 +5,7 @@ import sys
|
|
5
5
|
|
6
6
|
def main():
|
7
7
|
if os.uname().sysname != "Linux":
|
8
|
-
print("
|
8
|
+
print("only support uname Linux")
|
9
9
|
sys.exit(1)
|
10
10
|
verbose = ""
|
11
11
|
root_pid = 0
|
@@ -15,7 +15,7 @@ def main():
|
|
15
15
|
elif arg.isdigit():
|
16
16
|
root_pid = int(arg)
|
17
17
|
else:
|
18
|
-
print(f"
|
18
|
+
print(f"invalid args: {arg}")
|
19
19
|
sys.exit(1)
|
20
20
|
if root_pid in (0, 1):
|
21
21
|
cmd = ["ps", verbose + "afxopid,user,cmd"]
|
@@ -1,4 +1,5 @@
|
|
1
1
|
ladyrick/__init__.py,sha256=udDiOLu2L1jIvjgSCLdDsTnS2OA7kdj2BA9kit2maVo,1343
|
2
|
+
ladyrick/allgather.py,sha256=jhu6eYABIwdl2WTKbK9UR9ji8imhUIdGd-ITPfluYm0,3531
|
2
3
|
ladyrick/debug.py,sha256=jMVdL9cg1yFaX6Yfdw0Yq6LZMOl7av-DnpaR51eEV38,3083
|
3
4
|
ladyrick/loader.py,sha256=Ykg4yxK-UlhAGpBJnnHQLecEMkDpQETOenIs2okwdGY,2263
|
4
5
|
ladyrick/pickle.py,sha256=AjIm4H2_3bPOW-NPd9ow_sA3Z520LZwf1VeX6742ADs,2147
|
@@ -8,8 +9,8 @@ ladyrick/torch.py,sha256=CEdHYaOZ00StZettf4MoIB3tMF0fbzSIH-3pOqOMIZM,1977
|
|
8
9
|
ladyrick/typing.py,sha256=YQeApe63dk7yL4NS5ytlR6v3dLCii2-qsXNlUvjK-zw,203
|
9
10
|
ladyrick/utils.py,sha256=jRRaqC6kNbCJPGeE0YisFgis-wiuINLik1mcUQtytow,608
|
10
11
|
ladyrick/vars.py,sha256=VbFh2u7XybUaBuiYEXBa4sOmoS99vc2AIXdYLBh8vjk,3763
|
11
|
-
ladyrick/cli/multi_ssh.py,sha256=
|
12
|
-
ladyrick/cli/psf.py,sha256=
|
12
|
+
ladyrick/cli/multi_ssh.py,sha256=faRIgradz3e6ZjwpPd77OB_u5FEJFCfVzGq3qrwJsgk,9220
|
13
|
+
ladyrick/cli/psf.py,sha256=JLk3gbPn7E3uuPBbzGvLgJmFQlilA6zg_Xlg7xW5jik,1146
|
13
14
|
ladyrick/cli/tee.py,sha256=UMJxSJLOEfbV43auVKRTIJ5ZAMAkAfj8byiFLk5PUHE,3579
|
14
15
|
ladyrick/cli/test_args.py,sha256=f5sUPDlcf6nbNf6UfLwZQI5g5LN8wlFBQZ10GLw22cg,212
|
15
16
|
ladyrick/cli/test_signal.py,sha256=uvVLbHlvpmYgZcfRF9Lcte0OJJrMDDsmId3BpSxeaOA,490
|
@@ -19,9 +20,9 @@ ladyrick/patch/rich_print.py,sha256=z3Ea1VCunXZvNvEDFHpoyWc8ydINmh-gOIJ1ssscs6s,
|
|
19
20
|
ladyrick/patch/python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
21
|
ladyrick/patch/python/__main__.py,sha256=BAGy1phd26WWcGM9TKHbqIpeZliVofopBndtMIPtDQ0,651
|
21
22
|
ladyrick/patch/python/usercustomize.py,sha256=8mYpcZ8p-l41fiSJue727n8cAmcEmUktObDYZDdLJfs,218
|
22
|
-
ladyrick-0.5.
|
23
|
-
ladyrick-0.5.
|
24
|
-
ladyrick-0.5.
|
25
|
-
ladyrick-0.5.
|
26
|
-
ladyrick-0.5.
|
27
|
-
ladyrick-0.5.
|
23
|
+
ladyrick-0.5.2.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
|
24
|
+
ladyrick-0.5.2.dist-info/METADATA,sha256=K3kyX7wNOmge7CX0mZl7HtOWl3ai9GVFFfOdN7gkUtA,883
|
25
|
+
ladyrick-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
26
|
+
ladyrick-0.5.2.dist-info/entry_points.txt,sha256=WVddlz7Wn9t1kQkA-HluVeeyKRJB3-7KrRpmWxfR2cM,195
|
27
|
+
ladyrick-0.5.2.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
|
28
|
+
ladyrick-0.5.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|