ladyrick 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ladyrick/allgather.py ADDED
@@ -0,0 +1,106 @@
1
+ import os
2
+ import socket
3
+ import struct
4
+ import sys
5
+ import time
6
+
7
+
8
class AllGather:
    """Gather one ``bytes`` payload from every rank and give every rank the full list.

    Rendezvous is configured through the ``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE`` and ``RANK`` environment variables. Rank 0 acts as the
    master: it listens on ``MASTER_PORT``, collects the payload of every other
    rank over TCP, and sends the complete rank-ordered list back to each worker.

    Wire protocol (all integers big-endian):
      worker -> master: rank (u32), payload length (u64), payload bytes
      master -> worker: WORLD_SIZE payload lengths (u64 each), then every
                        payload concatenated in rank order
    """

    # Sent by each worker ahead of its payload: (rank: u32, payload length: u64).
    _HEADER = struct.Struct("!IQ")

    def __init__(self):
        """Read the rendezvous configuration from the environment.

        Raises:
            RuntimeError: if any of MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK is unset.
            ValueError: if a numeric variable does not parse as an int, or RANK
                is not in ``[0, WORLD_SIZE)``.
        """
        if {"MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} - set(os.environ):
            raise RuntimeError("MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK env is required")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = int(os.environ["MASTER_PORT"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(f"RANK must be in [0, WORLD_SIZE={self.world_size}), got {self.rank}")

    def allgather(self, data: bytes) -> list[bytes]:
        """Exchange ``data`` with every rank.

        Blocks until all ranks have contributed. Returns a list of length
        ``world_size`` whose element ``i`` is the payload contributed by rank
        ``i``; every rank receives the same list.
        """
        if self.world_size == 1:
            return [data]
        if self.rank == 0:
            return self._run_master(data)
        return self._run_worker(data)

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int) -> bytes:
        """Read exactly ``size`` bytes from ``sock``.

        ``socket.recv`` may return fewer bytes than requested, so keep reading
        until the buffer is full.

        Raises:
            ConnectionError: if the peer closes the connection before ``size``
                bytes have arrived.
        """
        buf = bytearray(size)
        view = memoryview(buf)
        received = 0
        while received < size:
            n = sock.recv_into(view[received:], size - received)
            if n == 0:
                raise ConnectionError(f"connection closed after {received} of {size} bytes")
            received += n
        return bytes(buf)

    def _run_master(self, data: bytes) -> list[bytes]:
        """Rank 0: collect every worker's payload, then broadcast the full list.

        Exits the process with status 1 if the listening port cannot be bound.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            server_socket.bind(("0.0.0.0", self.master_port))
            server_socket.listen(self.world_size - 1)
        except OSError as e:
            server_socket.close()
            # stderr: stdout carries the gathered data when used via the CLI.
            print(f"Master failed to bind port: {e}", file=sys.stderr)
            sys.exit(1)

        conns: list[socket.socket] = []
        try:
            # Workers connect in arbitrary order, so key payloads by the rank
            # each worker reports instead of by accept order.
            by_rank: dict[int, bytes] = {0: data}
            for _ in range(self.world_size - 1):
                conn, _addr = server_socket.accept()
                conns.append(conn)
                rank, size = self._HEADER.unpack(self._recv_exact(conn, self._HEADER.size))
                if not 0 < rank < self.world_size or rank in by_rank:
                    raise RuntimeError(f"Master received invalid or duplicate rank {rank}")
                by_rank[rank] = self._recv_exact(conn, size)
            all_data = [by_rank[r] for r in range(self.world_size)]

            # broadcast: all lengths first, then every payload in rank order
            lens_pack = struct.pack(f"!{self.world_size}Q", *(len(d) for d in all_data))
            for conn in conns:
                conn.sendall(lens_pack)
                for d in all_data:
                    conn.sendall(d)
        finally:
            # Close every accepted connection even if collection/broadcast failed.
            for conn in conns:
                conn.close()
            server_socket.close()
        return all_data

    def _connect_to_master(self) -> socket.socket:
        """Connect to the master, retrying every 0.1 s until it accepts.

        A fresh socket is created for each attempt because a socket whose
        ``connect`` failed cannot be reused on every platform. Retries
        indefinitely on refusal/timeout; other errors propagate.
        """
        while True:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((self.master_addr, self.master_port))
            except (ConnectionRefusedError, socket.timeout):
                sock.close()
                time.sleep(0.1)
            except BaseException:
                sock.close()
                raise
            else:
                return sock

    def _run_worker(self, data: bytes) -> list[bytes]:
        """Rank > 0: send ``(rank, data)`` to the master and receive the full list."""
        worker_socket = self._connect_to_master()
        try:
            worker_socket.sendall(self._HEADER.pack(self.rank, len(data)))
            worker_socket.sendall(data)

            lens_struct = struct.Struct(f"!{self.world_size}Q")
            data_lens = lens_struct.unpack(self._recv_exact(worker_socket, lens_struct.size))
            all_data = [self._recv_exact(worker_socket, n) for n in data_lens]
        finally:
            worker_socket.close()
        return all_data
85
+
86
+
87
def main():
    """Command-line entry point: all-gather one message and print every result.

    The rendezvous settings are read by ``AllGather`` from the environment.
    Each gathered message is decoded and printed newline-separated, or
    NUL-separated when ``-0`` is given.
    """
    import argparse

    cli = argparse.ArgumentParser(prog="allgather")
    cli.add_argument("msg", type=str, help="msg to allgather")
    cli.add_argument("-0", action="store_true", help="separate by \\0", dest="zero")
    opts = cli.parse_args()

    messages = AllGather().allgather(opts.msg.encode())
    separator = "\0" if opts.zero else "\n"
    print(separator.join(m.decode() for m in messages), flush=True)
103
+
104
+
105
# Allow running the module directly, e.g. `python -m ladyrick.allgather`.
if __name__ == "__main__":
    main()
ladyrick/cli/multi_ssh.py CHANGED
@@ -158,18 +158,20 @@ class RemoteExecutor:
158
158
  def set_envs(cls, executors: list["RemoteExecutor"]):
159
159
  assert executors
160
160
  envs = {}
161
-
162
- cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
163
- master_ips = subprocess.check_output(cmd).decode().split()
164
- priority = {"172": 0, "192": 1, "10": 2}
165
- master_addr, cur_p = None, -1
166
- for ip in master_ips:
167
- prefix = ip.split(".", 1)[0]
168
- p = priority.get(prefix, 3)
169
- if p > cur_p:
170
- master_addr, cur_p = ip, p
171
- assert master_addr is not None
172
- envs["MASTER_ADDR"] = master_addr
161
+ if len(executors) > 1:
162
+ cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
163
+ master_ips = subprocess.check_output(cmd).decode().split()
164
+ priority = {"172": 0, "192": 1, "10": 2}
165
+ master_addr, cur_p = None, -1
166
+ for ip in master_ips:
167
+ prefix = ip.split(".", 1)[0]
168
+ p = priority.get(prefix, 3)
169
+ if p > cur_p:
170
+ master_addr, cur_p = ip, p
171
+ assert master_addr is not None
172
+ envs["MASTER_ADDR"] = master_addr
173
+ else:
174
+ envs["MASTER_ADDR"] = "127.0.0.1"
173
175
  envs["MASTER_PORT"] = str(random.randint(20000, 40000))
174
176
  envs["WORLD_SIZE"] = str(len(executors))
175
177
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ladyrick
3
- Version: 0.5.1
3
+ Version: 0.5.2
4
4
  Summary: ladyrick's tools
5
5
  Author-email: ladyrick <ladyrick@qq.com>
6
6
  License-Expression: MIT
@@ -1,4 +1,5 @@
1
1
  ladyrick/__init__.py,sha256=udDiOLu2L1jIvjgSCLdDsTnS2OA7kdj2BA9kit2maVo,1343
2
+ ladyrick/allgather.py,sha256=jhu6eYABIwdl2WTKbK9UR9ji8imhUIdGd-ITPfluYm0,3531
2
3
  ladyrick/debug.py,sha256=jMVdL9cg1yFaX6Yfdw0Yq6LZMOl7av-DnpaR51eEV38,3083
3
4
  ladyrick/loader.py,sha256=Ykg4yxK-UlhAGpBJnnHQLecEMkDpQETOenIs2okwdGY,2263
4
5
  ladyrick/pickle.py,sha256=AjIm4H2_3bPOW-NPd9ow_sA3Z520LZwf1VeX6742ADs,2147
@@ -8,7 +9,7 @@ ladyrick/torch.py,sha256=CEdHYaOZ00StZettf4MoIB3tMF0fbzSIH-3pOqOMIZM,1977
8
9
  ladyrick/typing.py,sha256=YQeApe63dk7yL4NS5ytlR6v3dLCii2-qsXNlUvjK-zw,203
9
10
  ladyrick/utils.py,sha256=jRRaqC6kNbCJPGeE0YisFgis-wiuINLik1mcUQtytow,608
10
11
  ladyrick/vars.py,sha256=VbFh2u7XybUaBuiYEXBa4sOmoS99vc2AIXdYLBh8vjk,3763
11
- ladyrick/cli/multi_ssh.py,sha256=uS2843LPUZTtr08CrXPUyw4PomagMmuLeZNgM-dAUZk,9086
12
+ ladyrick/cli/multi_ssh.py,sha256=faRIgradz3e6ZjwpPd77OB_u5FEJFCfVzGq3qrwJsgk,9220
12
13
  ladyrick/cli/psf.py,sha256=JLk3gbPn7E3uuPBbzGvLgJmFQlilA6zg_Xlg7xW5jik,1146
13
14
  ladyrick/cli/tee.py,sha256=UMJxSJLOEfbV43auVKRTIJ5ZAMAkAfj8byiFLk5PUHE,3579
14
15
  ladyrick/cli/test_args.py,sha256=f5sUPDlcf6nbNf6UfLwZQI5g5LN8wlFBQZ10GLw22cg,212
@@ -19,9 +20,9 @@ ladyrick/patch/rich_print.py,sha256=z3Ea1VCunXZvNvEDFHpoyWc8ydINmh-gOIJ1ssscs6s,
19
20
  ladyrick/patch/python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
21
  ladyrick/patch/python/__main__.py,sha256=BAGy1phd26WWcGM9TKHbqIpeZliVofopBndtMIPtDQ0,651
21
22
  ladyrick/patch/python/usercustomize.py,sha256=8mYpcZ8p-l41fiSJue727n8cAmcEmUktObDYZDdLJfs,218
22
- ladyrick-0.5.1.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
23
- ladyrick-0.5.1.dist-info/METADATA,sha256=YfwVxt9gm68kRBlwdY0DQKIqvuky_uAXw82hiVW5nZQ,883
24
- ladyrick-0.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
- ladyrick-0.5.1.dist-info/entry_points.txt,sha256=RlkKx2ecXAbDLfeEI_lIEvC21ZyEGli8hZmu8hIYE-E,159
26
- ladyrick-0.5.1.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
27
- ladyrick-0.5.1.dist-info/RECORD,,
23
+ ladyrick-0.5.2.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
24
+ ladyrick-0.5.2.dist-info/METADATA,sha256=K3kyX7wNOmge7CX0mZl7HtOWl3ai9GVFFfOdN7gkUtA,883
25
+ ladyrick-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
26
+ ladyrick-0.5.2.dist-info/entry_points.txt,sha256=WVddlz7Wn9t1kQkA-HluVeeyKRJB3-7KrRpmWxfR2cM,195
27
+ ladyrick-0.5.2.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
28
+ ladyrick-0.5.2.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  [console_scripts]
2
+ allgather = ladyrick.allgather:main
2
3
  ladyrick-tee = ladyrick.cli.tee:main
3
4
  multi-ssh = ladyrick.cli.multi_ssh:main
4
5
  pretty-print = ladyrick.pprint:main