ladyrick 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ladyrick/allgather.py +106 -0
- ladyrick/cli/multi_ssh.py +14 -12
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/METADATA +1 -1
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/RECORD +8 -7
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/entry_points.txt +1 -0
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/WHEEL +0 -0
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {ladyrick-0.5.1.dist-info → ladyrick-0.5.2.dist-info}/top_level.txt +0 -0
ladyrick/allgather.py
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
import os
|
2
|
+
import socket
|
3
|
+
import struct
|
4
|
+
import sys
|
5
|
+
import time
|
6
|
+
|
7
|
+
|
8
|
+
class AllGather:
    """All-gather one ``bytes`` payload per rank over plain TCP sockets.

    Configuration comes from the ``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE`` and ``RANK`` environment variables. Rank 0 is the master:
    it listens on ``0.0.0.0:MASTER_PORT``, collects one payload from each of
    the other ``WORLD_SIZE - 1`` ranks, and sends the full list back to each
    of them. Every rank's :meth:`allgather` returns the same list, indexed by
    rank.

    Wire protocol (unsigned 32-bit big-endian integers, so each payload must be
    smaller than 4 GiB):

    * worker -> master: ``rank``, ``len(payload)``, then the payload bytes;
    * master -> worker: ``WORLD_SIZE`` lengths, then every payload in rank order.
    """

    def __init__(self):
        """Read the rendezvous configuration from the environment.

        Raises:
            RuntimeError: if any of the four required variables is unset.
            ValueError: if ``RANK`` is not in ``[0, WORLD_SIZE)``.
        """
        if {"MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} - set(os.environ):
            raise RuntimeError("MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK env is required")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = int(os.environ["MASTER_PORT"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        # An explicit check rather than `assert`, which is stripped under `python -O`.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(
                f"RANK must be in [0, WORLD_SIZE), got RANK={self.rank}, WORLD_SIZE={self.world_size}"
            )

    def allgather(self, data: bytes) -> list[bytes]:
        """Contribute ``data`` and return every rank's payload, indexed by rank.

        Blocks until all ``WORLD_SIZE`` ranks have called ``allgather``.
        """
        if self.world_size == 1:
            # Single process: nothing to exchange (__init__ guarantees rank == 0).
            return [data]
        if self.rank == 0:
            return self._run_master(data)
        return self._run_worker(data)

    @staticmethod
    def _recv_exact(sock: socket.socket, n: int) -> bytes:
        """Read exactly ``n`` bytes from ``sock``.

        ``socket.recv`` may legally return fewer bytes than requested, so keep
        reading until the buffer is full.

        Raises:
            ConnectionError: if the peer closes the connection early.
        """
        buf = bytearray(n)
        view = memoryview(buf)
        got = 0
        while got < n:
            k = sock.recv_into(view[got:], n - got)
            if k == 0:
                raise ConnectionError(f"connection closed after {got} of {n} bytes")
            got += k
        return bytes(buf)

    def _run_master(self, data: bytes) -> list[bytes]:
        """Rank 0: collect every worker's payload, then broadcast the full list.

        Exits the process with status 1 if the listening port cannot be bound.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            server_socket.bind(("0.0.0.0", self.master_port))
            server_socket.listen(self.world_size - 1)
        except OSError as e:
            server_socket.close()
            # stderr, so the message does not mix with gathered output on stdout.
            print(f"Master failed to bind port: {e}", file=sys.stderr)
            sys.exit(1)

        received: dict[int, bytes] = {0: data}
        conns: list[socket.socket] = []
        try:
            # collect: workers connect in arbitrary order, so each one tells us
            # its rank and we file its payload under that rank.
            for _ in range(1, self.world_size):
                conn, _addr = server_socket.accept()
                conns.append(conn)
                worker_rank, data_len = struct.unpack("!II", self._recv_exact(conn, 8))
                if not 0 < worker_rank < self.world_size or worker_rank in received:
                    raise RuntimeError(f"unexpected or duplicate worker rank {worker_rank}")
                received[worker_rank] = self._recv_exact(conn, data_len)

            all_data = [received[r] for r in range(self.world_size)]

            # broadcast: the lengths header, then every payload in rank order.
            header = struct.pack(f"!{self.world_size}I", *(len(d) for d in all_data))
            for conn in conns:
                conn.sendall(header)
                for d in all_data:
                    conn.sendall(d)
        finally:
            # Close every accepted connection even if collecting/broadcasting failed.
            for conn in conns:
                conn.close()
            server_socket.close()
        return all_data

    def _connect_to_master(self) -> socket.socket:
        """Connect to the master, retrying every 0.1 s until it is listening."""
        while True:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((self.master_addr, self.master_port))
            except (ConnectionRefusedError, socket.timeout):
                # A socket whose connect() failed is not portably reusable, so
                # discard it and retry with a fresh one.
                sock.close()
                time.sleep(0.1)
            except BaseException:
                sock.close()
                raise
            else:
                return sock

    def _run_worker(self, data: bytes) -> list[bytes]:
        """Ranks 1..N-1: send ``(rank, data)`` to the master and receive all payloads."""
        worker_socket = self._connect_to_master()
        try:
            worker_socket.sendall(struct.pack("!II", self.rank, len(data)))
            worker_socket.sendall(data)

            header = self._recv_exact(worker_socket, 4 * self.world_size)
            data_lens = struct.unpack(f"!{self.world_size}I", header)
            return [self._recv_exact(worker_socket, n) for n in data_lens]
        finally:
            worker_socket.close()
|
85
|
+
|
86
|
+
|
87
|
+
def main():
    """Command-line entry point: all-gather one string and print every rank's copy.

    The positional ``msg`` argument is encoded with ``str.encode()`` (UTF-8),
    gathered with :class:`AllGather` (configured from the environment), and
    each gathered payload is decoded and printed. With ``-0`` the payloads are
    joined by ``\\0``; otherwise they are joined by newlines.
    """
    import argparse

    cli = argparse.ArgumentParser(prog="allgather")
    cli.add_argument("msg", type=str, help="msg to allgather")
    cli.add_argument("-0", action="store_true", help="separate by \\0", dest="zero")
    opts = cli.parse_args()

    payloads = AllGather().allgather(opts.msg.encode())
    separator = "\0" if opts.zero else "\n"
    print(separator.join(p.decode() for p in payloads), flush=True)


if __name__ == "__main__":
    main()
|
ladyrick/cli/multi_ssh.py
CHANGED
@@ -158,18 +158,20 @@ class RemoteExecutor:
|
|
158
158
|
def set_envs(cls, executors: list["RemoteExecutor"]):
|
159
159
|
assert executors
|
160
160
|
envs = {}
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
161
|
+
if len(executors) > 1:
|
162
|
+
cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
|
163
|
+
master_ips = subprocess.check_output(cmd).decode().split()
|
164
|
+
priority = {"172": 0, "192": 1, "10": 2}
|
165
|
+
master_addr, cur_p = None, -1
|
166
|
+
for ip in master_ips:
|
167
|
+
prefix = ip.split(".", 1)[0]
|
168
|
+
p = priority.get(prefix, 3)
|
169
|
+
if p > cur_p:
|
170
|
+
master_addr, cur_p = ip, p
|
171
|
+
assert master_addr is not None
|
172
|
+
envs["MASTER_ADDR"] = master_addr
|
173
|
+
else:
|
174
|
+
envs["MASTER_ADDR"] = "127.0.0.1"
|
173
175
|
envs["MASTER_PORT"] = str(random.randint(20000, 40000))
|
174
176
|
envs["WORLD_SIZE"] = str(len(executors))
|
175
177
|
|
@@ -1,4 +1,5 @@
|
|
1
1
|
ladyrick/__init__.py,sha256=udDiOLu2L1jIvjgSCLdDsTnS2OA7kdj2BA9kit2maVo,1343
|
2
|
+
ladyrick/allgather.py,sha256=jhu6eYABIwdl2WTKbK9UR9ji8imhUIdGd-ITPfluYm0,3531
|
2
3
|
ladyrick/debug.py,sha256=jMVdL9cg1yFaX6Yfdw0Yq6LZMOl7av-DnpaR51eEV38,3083
|
3
4
|
ladyrick/loader.py,sha256=Ykg4yxK-UlhAGpBJnnHQLecEMkDpQETOenIs2okwdGY,2263
|
4
5
|
ladyrick/pickle.py,sha256=AjIm4H2_3bPOW-NPd9ow_sA3Z520LZwf1VeX6742ADs,2147
|
@@ -8,7 +9,7 @@ ladyrick/torch.py,sha256=CEdHYaOZ00StZettf4MoIB3tMF0fbzSIH-3pOqOMIZM,1977
|
|
8
9
|
ladyrick/typing.py,sha256=YQeApe63dk7yL4NS5ytlR6v3dLCii2-qsXNlUvjK-zw,203
|
9
10
|
ladyrick/utils.py,sha256=jRRaqC6kNbCJPGeE0YisFgis-wiuINLik1mcUQtytow,608
|
10
11
|
ladyrick/vars.py,sha256=VbFh2u7XybUaBuiYEXBa4sOmoS99vc2AIXdYLBh8vjk,3763
|
11
|
-
ladyrick/cli/multi_ssh.py,sha256=
|
12
|
+
ladyrick/cli/multi_ssh.py,sha256=faRIgradz3e6ZjwpPd77OB_u5FEJFCfVzGq3qrwJsgk,9220
|
12
13
|
ladyrick/cli/psf.py,sha256=JLk3gbPn7E3uuPBbzGvLgJmFQlilA6zg_Xlg7xW5jik,1146
|
13
14
|
ladyrick/cli/tee.py,sha256=UMJxSJLOEfbV43auVKRTIJ5ZAMAkAfj8byiFLk5PUHE,3579
|
14
15
|
ladyrick/cli/test_args.py,sha256=f5sUPDlcf6nbNf6UfLwZQI5g5LN8wlFBQZ10GLw22cg,212
|
@@ -19,9 +20,9 @@ ladyrick/patch/rich_print.py,sha256=z3Ea1VCunXZvNvEDFHpoyWc8ydINmh-gOIJ1ssscs6s,
|
|
19
20
|
ladyrick/patch/python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
21
|
ladyrick/patch/python/__main__.py,sha256=BAGy1phd26WWcGM9TKHbqIpeZliVofopBndtMIPtDQ0,651
|
21
22
|
ladyrick/patch/python/usercustomize.py,sha256=8mYpcZ8p-l41fiSJue727n8cAmcEmUktObDYZDdLJfs,218
|
22
|
-
ladyrick-0.5.
|
23
|
-
ladyrick-0.5.
|
24
|
-
ladyrick-0.5.
|
25
|
-
ladyrick-0.5.
|
26
|
-
ladyrick-0.5.
|
27
|
-
ladyrick-0.5.
|
23
|
+
ladyrick-0.5.2.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
|
24
|
+
ladyrick-0.5.2.dist-info/METADATA,sha256=K3kyX7wNOmge7CX0mZl7HtOWl3ai9GVFFfOdN7gkUtA,883
|
25
|
+
ladyrick-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
26
|
+
ladyrick-0.5.2.dist-info/entry_points.txt,sha256=WVddlz7Wn9t1kQkA-HluVeeyKRJB3-7KrRpmWxfR2cM,195
|
27
|
+
ladyrick-0.5.2.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
|
28
|
+
ladyrick-0.5.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|