ladyrick 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ladyrick/allgather.py ADDED
@@ -0,0 +1,106 @@
1
+ import os
2
+ import socket
3
+ import struct
4
+ import sys
5
+ import time
6
+
7
+
8
class AllGather:
    """Gather one ``bytes`` payload from every rank onto every rank over TCP.

    Rank 0 acts as a rendezvous server listening on ``MASTER_PORT``. Every other
    rank connects to ``MASTER_ADDR:MASTER_PORT``, sends its payload, and receives
    the payloads of all ranks back. :meth:`allgather` returns a list indexed by
    rank on every rank.

    Configuration comes from the ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE``
    and ``RANK`` environment variables.

    Wire protocol (integers are big-endian unsigned 32-bit; all ranks must run
    the same version of this class):
      worker -> master: rank, payload length, payload bytes
      master -> worker: WORLD_SIZE payload lengths, then every payload in rank order
    """

    def __init__(self):
        """Read the rendezvous configuration from the environment.

        Raises:
            RuntimeError: if a required env var is missing, or RANK is not in
                ``[0, WORLD_SIZE)``.
            ValueError: if MASTER_PORT, WORLD_SIZE or RANK is not an integer.
        """
        if {"MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} - set(os.environ):
            raise RuntimeError("MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK env is required")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = int(os.environ["MASTER_PORT"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        # Explicit check instead of `assert` so validation survives `python -O`.
        if not 0 <= self.rank < self.world_size:
            raise RuntimeError(
                f"RANK must be in [0, WORLD_SIZE), got RANK={self.rank}, WORLD_SIZE={self.world_size}"
            )

    def allgather(self, data: bytes) -> list[bytes]:
        """Exchange ``data`` with all ranks.

        Returns:
            A list of ``world_size`` payloads, where element ``i`` is the payload
            contributed by rank ``i``.
        """
        if self.world_size == 1:
            assert self.rank == 0
            return [data]

        if self.rank == 0:
            return self._run_master(data)
        return self._run_worker(data)

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int) -> bytes:
        """Read exactly ``size`` bytes from ``sock``.

        ``socket.recv(n)`` may return fewer than ``n`` bytes, so keep reading
        until the whole message has arrived.

        Raises:
            ConnectionError: if the peer closes the connection before ``size``
                bytes were received.
        """
        buf = bytearray(size)
        view = memoryview(buf)
        received = 0
        while received < size:
            n = sock.recv_into(view[received:], size - received)
            if n == 0:
                raise ConnectionError(f"connection closed after {received} of {size} bytes")
            received += n
        return bytes(buf)

    def _run_master(self, data: bytes) -> list[bytes]:
        """Rank 0: collect every worker's payload, then send the full list to each worker."""
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            server_socket.bind(("0.0.0.0", self.master_port))
            server_socket.listen(self.world_size - 1)
        except OSError as e:
            server_socket.close()
            # stderr, so the error is not mixed into gathered output printed on stdout
            print(f"Master failed to bind port: {e}", file=sys.stderr)
            sys.exit(1)

        received: dict[int, bytes] = {0: data}
        conns: list[socket.socket] = []
        try:
            # collect: workers connect in arbitrary order, so each one announces its rank
            for _ in range(1, self.world_size):
                conn, _addr = server_socket.accept()
                conns.append(conn)
                worker_rank, data_len = struct.unpack("!II", self._recv_exact(conn, 8))
                if not 0 < worker_rank < self.world_size or worker_rank in received:
                    raise RuntimeError(f"unexpected or duplicate worker rank {worker_rank}")
                received[worker_rank] = self._recv_exact(conn, data_len)

            all_data = [received[r] for r in range(self.world_size)]

            # broadcast
            data_len_pack = struct.pack(f"!{self.world_size}I", *(len(d) for d in all_data))
            for conn in conns:
                conn.sendall(data_len_pack)
                for d in all_data:
                    conn.sendall(d)
        finally:
            # Close every accepted connection, including on failure paths.
            for conn in conns:
                conn.close()
            server_socket.close()
        return all_data

    def _connect_to_master(self) -> socket.socket:
        """Connect to the master, retrying every 0.1s while it is not yet listening."""
        while True:
            # A fresh socket per attempt: per POSIX, a socket whose connect()
            # failed is in an unspecified state and must not be reused.
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((self.master_addr, self.master_port))
                return sock
            except (ConnectionRefusedError, socket.timeout):
                sock.close()
                time.sleep(0.1)
            except BaseException:
                sock.close()
                raise

    def _run_worker(self, data: bytes) -> list[bytes]:
        """Rank > 0: send (rank, payload) to the master and receive every rank's payload."""
        with self._connect_to_master() as worker_socket:
            worker_socket.sendall(struct.pack("!II", self.rank, len(data)))
            worker_socket.sendall(data)

            header = self._recv_exact(worker_socket, 4 * self.world_size)
            data_lens = struct.unpack(f"!{self.world_size}I", header)
            return [self._recv_exact(worker_socket, n) for n in data_lens]
85
+
86
+
87
def main():
    """Command-line entry point: allgather MSG and print every gathered message.

    The gathered messages are printed in the order returned by
    ``AllGather.allgather``. They are joined with newlines, or with NUL
    characters when ``-0`` is passed.
    """
    import argparse

    cli = argparse.ArgumentParser(prog="allgather")
    cli.add_argument("msg", type=str, help="msg to allgather")
    cli.add_argument("-0", action="store_true", help="separate by \\0", dest="zero")
    opts = cli.parse_args()

    messages = [chunk.decode() for chunk in AllGather().allgather(opts.msg.encode())]
    separator = "\0" if opts.zero else "\n"
    print(separator.join(messages), flush=True)


if __name__ == "__main__":
    main()
ladyrick/cli/multi_ssh.py CHANGED
@@ -1,17 +1,9 @@
1
- import argparse
2
- import dataclasses
3
- import itertools
4
1
  import json
5
2
  import os
6
- import pathlib
7
- import random
8
3
  import select
9
- import shlex
10
4
  import signal
11
5
  import subprocess
12
6
  import sys
13
- import time
14
- import uuid
15
7
 
16
8
 
17
9
  def log(msg):
@@ -37,19 +29,18 @@ def remote_head():
37
29
  extra_envs = json.loads(sys.argv[2])
38
30
  os.environ.update(extra_envs)
39
31
 
40
- parser = argparse.ArgumentParser(prog=REMOTE_HEAD_PROG_NAME)
41
- parser.add_argument("cmd", type=str, nargs="+", help="cmd")
42
- args = parser.parse_args(sys.argv[3:])
32
+ cmd = sys.argv[3:]
43
33
  try:
44
34
  from setproctitle import setproctitle
45
35
 
46
- setproctitle(" ".join([REMOTE_HEAD_PROG_NAME] + args.cmd))
36
+ setproctitle(" ".join([REMOTE_HEAD_PROG_NAME] + cmd))
47
37
  except ImportError:
48
38
  pass
49
39
 
50
40
  # start child process
51
41
  child = subprocess.Popen(
52
- args.cmd,
42
+ cmd,
43
+ stdin=subprocess.PIPE,
53
44
  stdout=sys.stdout,
54
45
  stderr=sys.stderr,
55
46
  start_new_session=True,
@@ -97,6 +88,15 @@ if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == REMOTE_HEAD_P
97
88
 
98
89
 
99
90
  # ----- remote_head end ----- #
91
+ if True:
92
+ import argparse
93
+ import dataclasses
94
+ import itertools
95
+ import pathlib
96
+ import random
97
+ import shlex
98
+ import time
99
+ import uuid
100
100
 
101
101
 
102
102
  @dataclasses.dataclass
@@ -158,18 +158,20 @@ class RemoteExecutor:
158
158
  def set_envs(cls, executors: list["RemoteExecutor"]):
159
159
  assert executors
160
160
  envs = {}
161
-
162
- cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
163
- master_ips = subprocess.check_output(cmd).decode().split()
164
- priority = {"172": 0, "192": 1, "10": 2}
165
- master_addr, cur_p = None, -1
166
- for ip in master_ips:
167
- prefix = ip.split(".", 1)[0]
168
- p = priority.get(prefix, 3)
169
- if p > cur_p:
170
- master_addr, cur_p = ip, p
171
- assert master_addr is not None
172
- envs["MASTER_ADDR"] = master_addr
161
+ if len(executors) > 1:
162
+ cmd = cls.make_ssh_cmd(executors[0].host, "hostname -I")
163
+ master_ips = subprocess.check_output(cmd).decode().split()
164
+ priority = {"172": 0, "192": 1, "10": 2}
165
+ master_addr, cur_p = None, -1
166
+ for ip in master_ips:
167
+ prefix = ip.split(".", 1)[0]
168
+ p = priority.get(prefix, 3)
169
+ if p > cur_p:
170
+ master_addr, cur_p = ip, p
171
+ assert master_addr is not None
172
+ envs["MASTER_ADDR"] = master_addr
173
+ else:
174
+ envs["MASTER_ADDR"] = "127.0.0.1"
173
175
  envs["MASTER_PORT"] = str(random.randint(20000, 40000))
174
176
  envs["WORLD_SIZE"] = str(len(executors))
175
177
 
@@ -213,10 +215,6 @@ def signal_repeat_checker(sig_to_check, count, duration):
213
215
  return checker
214
216
 
215
217
 
216
- def get_common_envs():
217
- pass
218
-
219
-
220
218
  def main():
221
219
  parser = argparse.ArgumentParser(prog="multi-ssh", add_help=False)
222
220
  parser.add_argument("-h", type=str, action="append", help="hosts to connect. order is 1")
@@ -225,13 +223,18 @@ def main():
225
223
  parser.add_argument("-l", type=str, help="ssh login User")
226
224
  parser.add_argument("-o", type=str, action="append", help="ssh options")
227
225
  parser.add_argument("-F", type=str, help="ssh config file")
228
- parser.add_argument("cmd", type=str, nargs="+", help="cmd")
229
226
  parser.add_argument("--hosts-config", type=str, action="append", help="hosts config string. order is 2")
230
227
  parser.add_argument("--hosts-config-file", type=str, action="append", help="hosts config file. order is 3")
231
228
  parser.add_argument("--help", action="help", default=argparse.SUPPRESS, help="show this help message and exit")
229
+ parser.add_argument("cmd", type=str, nargs=argparse.REMAINDER, help="cmd")
232
230
 
233
231
  args = parser.parse_args()
234
232
 
233
+ if not args.cmd:
234
+ print("cmd is required\n")
235
+ parser.print_help()
236
+ sys.exit(1)
237
+
235
238
  hosts = [
236
239
  Host(hn, args.F, args.l, args.p, args.i, args.o)
237
240
  for hn in itertools.chain.from_iterable(h.split(",") for h in args.h or [])
@@ -255,6 +258,11 @@ def main():
255
258
  )
256
259
  )
257
260
 
261
+ if not hosts:
262
+ print("hosts is required. specify hosts by -h, --hosts-config or --hosts-config-file\n")
263
+ parser.print_help()
264
+ sys.exit(1)
265
+
258
266
  executors = [RemoteExecutor(host, args.cmd) for host in hosts]
259
267
 
260
268
  RemoteExecutor.set_envs(executors)
ladyrick/cli/psf.py CHANGED
@@ -5,7 +5,7 @@ import sys
5
5
 
6
6
  def main():
7
7
  if os.uname().sysname != "Linux":
8
- print("Only support uname Linux")
8
+ print("only support uname Linux")
9
9
  sys.exit(1)
10
10
  verbose = ""
11
11
  root_pid = 0
@@ -15,7 +15,7 @@ def main():
15
15
  elif arg.isdigit():
16
16
  root_pid = int(arg)
17
17
  else:
18
- print(f"Invalid args: {arg}")
18
+ print(f"invalid args: {arg}")
19
19
  sys.exit(1)
20
20
  if root_pid in (0, 1):
21
21
  cmd = ["ps", verbose + "afxopid,user,cmd"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ladyrick
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: ladyrick's tools
5
5
  Author-email: ladyrick <ladyrick@qq.com>
6
6
  License-Expression: MIT
@@ -1,4 +1,5 @@
1
1
  ladyrick/__init__.py,sha256=udDiOLu2L1jIvjgSCLdDsTnS2OA7kdj2BA9kit2maVo,1343
2
+ ladyrick/allgather.py,sha256=jhu6eYABIwdl2WTKbK9UR9ji8imhUIdGd-ITPfluYm0,3531
2
3
  ladyrick/debug.py,sha256=jMVdL9cg1yFaX6Yfdw0Yq6LZMOl7av-DnpaR51eEV38,3083
3
4
  ladyrick/loader.py,sha256=Ykg4yxK-UlhAGpBJnnHQLecEMkDpQETOenIs2okwdGY,2263
4
5
  ladyrick/pickle.py,sha256=AjIm4H2_3bPOW-NPd9ow_sA3Z520LZwf1VeX6742ADs,2147
@@ -8,8 +9,8 @@ ladyrick/torch.py,sha256=CEdHYaOZ00StZettf4MoIB3tMF0fbzSIH-3pOqOMIZM,1977
8
9
  ladyrick/typing.py,sha256=YQeApe63dk7yL4NS5ytlR6v3dLCii2-qsXNlUvjK-zw,203
9
10
  ladyrick/utils.py,sha256=jRRaqC6kNbCJPGeE0YisFgis-wiuINLik1mcUQtytow,608
10
11
  ladyrick/vars.py,sha256=VbFh2u7XybUaBuiYEXBa4sOmoS99vc2AIXdYLBh8vjk,3763
11
- ladyrick/cli/multi_ssh.py,sha256=T585QpD2aSt-XKELHdiF67ukc-jKf1Q43ipA1jyOZU8,8923
12
- ladyrick/cli/psf.py,sha256=L37rvhODzcdRsqWM_G-xkem_f783ggiYfc95SHcELp0,1146
12
+ ladyrick/cli/multi_ssh.py,sha256=faRIgradz3e6ZjwpPd77OB_u5FEJFCfVzGq3qrwJsgk,9220
13
+ ladyrick/cli/psf.py,sha256=JLk3gbPn7E3uuPBbzGvLgJmFQlilA6zg_Xlg7xW5jik,1146
13
14
  ladyrick/cli/tee.py,sha256=UMJxSJLOEfbV43auVKRTIJ5ZAMAkAfj8byiFLk5PUHE,3579
14
15
  ladyrick/cli/test_args.py,sha256=f5sUPDlcf6nbNf6UfLwZQI5g5LN8wlFBQZ10GLw22cg,212
15
16
  ladyrick/cli/test_signal.py,sha256=uvVLbHlvpmYgZcfRF9Lcte0OJJrMDDsmId3BpSxeaOA,490
@@ -19,9 +20,9 @@ ladyrick/patch/rich_print.py,sha256=z3Ea1VCunXZvNvEDFHpoyWc8ydINmh-gOIJ1ssscs6s,
19
20
  ladyrick/patch/python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
21
  ladyrick/patch/python/__main__.py,sha256=BAGy1phd26WWcGM9TKHbqIpeZliVofopBndtMIPtDQ0,651
21
22
  ladyrick/patch/python/usercustomize.py,sha256=8mYpcZ8p-l41fiSJue727n8cAmcEmUktObDYZDdLJfs,218
22
- ladyrick-0.5.0.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
23
- ladyrick-0.5.0.dist-info/METADATA,sha256=LqfRIxtvCt1XMtpvjjacYbq7qBPoSiZMK8sUdmOs8Pw,883
24
- ladyrick-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
- ladyrick-0.5.0.dist-info/entry_points.txt,sha256=RlkKx2ecXAbDLfeEI_lIEvC21ZyEGli8hZmu8hIYE-E,159
26
- ladyrick-0.5.0.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
27
- ladyrick-0.5.0.dist-info/RECORD,,
23
+ ladyrick-0.5.2.dist-info/licenses/LICENSE,sha256=EeNAFxYAOYEmo2YEM7Zk5Oknq4RI0XMAbk4Rgoem6fs,1065
24
+ ladyrick-0.5.2.dist-info/METADATA,sha256=K3kyX7wNOmge7CX0mZl7HtOWl3ai9GVFFfOdN7gkUtA,883
25
+ ladyrick-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
26
+ ladyrick-0.5.2.dist-info/entry_points.txt,sha256=WVddlz7Wn9t1kQkA-HluVeeyKRJB3-7KrRpmWxfR2cM,195
27
+ ladyrick-0.5.2.dist-info/top_level.txt,sha256=RIC3-Jty2qzLYXSOr7fOu1loTwlMU9cF6MFeGIROxWU,9
28
+ ladyrick-0.5.2.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  [console_scripts]
2
+ allgather = ladyrick.allgather:main
2
3
  ladyrick-tee = ladyrick.cli.tee:main
3
4
  multi-ssh = ladyrick.cli.multi_ssh:main
4
5
  pretty-print = ladyrick.pprint:main