addftool 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
addftool/blob.py CHANGED
@@ -101,12 +101,13 @@ def install_main(args):
101
101
  def mount_main(args):
102
102
  sas_token = get_token(args, info=True)
103
103
 
104
+ cache_gate = 0 if args.no_cache else 1
104
105
  template = {
105
106
  'logging': {'type': 'silent', 'level': 'log_off'},
106
107
  'components': ['libfuse', 'file_cache', 'attr_cache', 'azstorage'],
107
- 'libfuse': {'attribute-expiration-sec': 120, 'entry-expiration-sec': 120, 'negative-entry-expiration-sec': 240},
108
- 'file_cache': {'path': '', 'timeout-sec': 120, 'max-size-mb': 1024 * int(args.file_cache_size)},
109
- 'attr_cache': {'timeout-sec': 7200},
108
+ 'libfuse': {'attribute-expiration-sec': 120 * cache_gate, 'entry-expiration-sec': 120 * cache_gate, 'negative-entry-expiration-sec': 240 * cache_gate},
109
+ 'file_cache': {'path': '', 'timeout-sec': 120 * cache_gate, 'max-size-mb': 1024 * int(args.file_cache_size)},
110
+ 'attr_cache': {'timeout-sec': 7200 * cache_gate},
110
111
  'azstorage': {
111
112
  'type': 'block',
112
113
  'endpoint': '',
@@ -145,6 +146,8 @@ def mount_main(args):
145
146
  print("Create config file: ", temp_config)
146
147
 
147
148
  command = f"blobfuse2 mount {args.mount} --config-file={temp_config}"
149
+ if args.no_cache:
150
+ command += " -o direct_io"
148
151
  if args.allow_other:
149
152
  command += " --allow-other"
150
153
  # to avoid "Error: fusermount3: option allow_other only allowed if 'user_allow_other' is set in /etc/fuse.conf"
@@ -235,6 +238,7 @@ def add_args(parser):
235
238
  mount_parser.add_argument("-o", "--allow-other", help="allow other", action="store_true", default=False)
236
239
  mount_parser.add_argument("--file-cache-size", help="file cache size (GB)", default=256)
237
240
  mount_parser.add_argument("--sudo", help="sudo", action="store_true")
241
+ mount_parser.add_argument("--no_cache", help="no cache", action="store_true")
238
242
 
239
243
  token_parser = subparsers.add_parser('token', help='Token help')
240
244
  add_api(token_parser)
@@ -3,13 +3,20 @@ import time
3
3
  import fnmatch
4
4
  import subprocess
5
5
  import hashlib
6
+ import sys
7
+ import warnings
8
+ try:
9
+ from addftool.ssh import get_client, get_ssh_config, handle_hosts_outputs
10
+ from pssh.clients import ParallelSSHClient
11
+ except ImportError:
12
+ pass
6
13
  from concurrent.futures import ThreadPoolExecutor
7
-
8
- from fabric import Connection, ThreadingGroup
14
+ import gevent
9
15
 
10
16
  try:
11
17
  import torch
12
18
  import torch.distributed as dist
19
+ import torch.multiprocessing as mp
13
20
  from torch.distributed import init_process_group, destroy_process_group
14
21
  _torch_is_available = True
15
22
  except ImportError:
@@ -29,9 +36,8 @@ def add_args(parser):
29
36
 
30
37
  parser.add_argument("--md5_verify", action='store_true', default=False,
31
38
  help="whether to verify the md5 of the file after broadcast, default is False.")
32
- parser.add_argument("--port", help="the port for torchrun, default is 29501", type=int, default=29501)
33
- parser.add_argument("--torchrun_alias", type=str, default="torchrun",
34
- help="the alias of torchrun, default is torchrun. If you use torchrun, please set it to torchrun.")
39
+ parser.add_argument("--port", help="the port for torch, default is 29501", type=int, default=29501)
40
+ parser.add_argument("--python_alias", type=str, default="python")
35
41
  parser.add_argument("--transfer_ranks_per_node", type=int, default=8,
36
42
  help="the number of ranks per node to transfer the files, default is 8.")
37
43
 
@@ -55,28 +61,13 @@ def add_args(parser):
55
61
  help="the blob url to download from, default is empty. " \
56
62
  "Only node-0 will download the files from the blob url, " \
57
63
  "If empty, will transfer the files from the node-0's local folder.")
64
+
65
+ parser.add_argument("--worker_args", type=str, default="")
58
66
 
59
67
  # distributed downloader from blob
60
68
  parser.add_argument("folder", help="the folder need to broadcast", type=str)
61
69
 
62
70
 
63
- class ConnectionWithCommand(Connection):
64
- def __init__(self, host, temp_config_dir, puts, command):
65
- super().__init__(host)
66
- self.command = command
67
- self.puts = puts
68
- self.temp_config_dir = temp_config_dir
69
-
70
- def run(self, command, **kwargs):
71
- super().run(f"mkdir -p {self.temp_config_dir}", **kwargs)
72
- for src, dest in self.puts:
73
- self.put(src, remote=dest)
74
- print(f"Running command on {self.original_host}: {self.command}")
75
- super().run(self.command, **kwargs)
76
- if command:
77
- super().run(command, **kwargs)
78
-
79
-
80
71
  def get_ip_via_ssh(hostname):
81
72
  if hostname == "localhost":
82
73
  return "127.0.0.1"
@@ -150,20 +141,40 @@ def broadcast_folder_main(args):
150
141
  host_list.append(line)
151
142
 
152
143
  print(f"Find {len(host_list)} hosts in hostfile: {args.hostfile}")
153
- connection_list = []
154
144
 
155
145
  remote_temp_config_dir = "/tmp/broadcast_temp_config_dir"
156
-
157
146
  master_addr = get_ip_via_ssh(host_list[0])
147
+
148
+ user_ssh_config, systen_ssh_config = get_ssh_config()
149
+ not_successed_hosts = host_list[:]
150
+ for try_count in range(3):
151
+ client = get_client(not_successed_hosts, user_ssh_configs=user_ssh_config, system_ssh_configs=systen_ssh_config)
152
+ mkdir_cmds = client.run_command(f"mkdir -p {remote_temp_config_dir}", stop_on_errors=False)
153
+ client.join(mkdir_cmds, timeout=10)
154
+ assert isinstance(client, ParallelSSHClient), "Failed to create ParallelSSHClient"
155
+ cmds = client.scp_send(__file__, os.path.join(remote_temp_config_dir, "broadcast_folder.py"))
156
+ gevent.joinall(cmds, raise_error=False, timeout=10)
157
+ resend_hosts = []
158
+ for host, cmd in zip(not_successed_hosts, cmds):
159
+ if cmd.exception:
160
+ print(f"Failed to copy file to {host}: {cmd.exception}")
161
+ resend_hosts.append(host)
162
+
163
+ if len(resend_hosts) == 0:
164
+ break
165
+ not_successed_hosts = resend_hosts
166
+
167
+ if len(resend_hosts) > 0:
168
+ print(f"Failed to copy file to {len(resend_hosts)} hosts: {resend_hosts}")
169
+ sys.exit(1)
158
170
 
171
+ host_commands = []
159
172
  for i, host in enumerate(host_list):
160
- put_commands = []
161
- put_commands.append((__file__, os.path.join(remote_temp_config_dir, "broadcast.py")))
162
- commnads = "NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
173
+ commnads = "PYTHONUNBUFFERED=1 NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
163
174
  if os.environ.get("SAS_TOKEN") is not None and i == 0:
164
175
  commnads += f"SAS_TOKEN=\"{os.environ['SAS_TOKEN']}\" "
165
- commnads += f"{args.torchrun_alias} --nproc_per_node={args.transfer_ranks_per_node} --nnodes={len(host_list)} --node_rank={i} --master_addr={master_addr} --master_port={args.port}"
166
- commnads += f" {remote_temp_config_dir}/broadcast.py {args.folder} --tool {args.tool} --transfer_ranks_per_node {args.transfer_ranks_per_node} "
176
+ commnads += f"{args.python_alias} {remote_temp_config_dir}/broadcast_folder.py {args.folder} --tool {args.tool} --transfer_ranks_per_node {args.transfer_ranks_per_node} "
177
+ commnads += f" --worker_args {master_addr}_{args.port}_{len(host_list)}_{args.transfer_ranks_per_node}_{i} "
167
178
  if args.contain_md5_files:
168
179
  commnads += " --contain_md5_files"
169
180
  if args.include_string:
@@ -174,11 +185,20 @@ def broadcast_folder_main(args):
174
185
  commnads += f" --from_blob_url {args.from_blob_url}"
175
186
  if args.md5_verify:
176
187
  commnads += " --md5_verify"
177
-
178
- connection_list.append(ConnectionWithCommand(host, remote_temp_config_dir, put_commands, commnads))
188
+ host_commands.append(commnads)
189
+ print(f"Run command on {host}: {commnads}")
190
+
191
+ client = get_client(host_list, user_ssh_configs=user_ssh_config, system_ssh_configs=systen_ssh_config)
192
+ if True:
193
+ cmds = client.run_command("%s", host_args=host_commands, stop_on_errors=False)
194
+ handle_hosts_outputs(cmds)
195
+ for host, cmd in zip(host_list, cmds):
196
+ if cmd.exception or cmd.exit_code != 0:
197
+ print(f"Failed to run command on {host}: {cmd.exception}, EXIT CODE: {cmd.exit_code}")
198
+ sys.exit(1)
199
+ print(f"Command on {host} finished with exit code {cmd.exit_code}")
179
200
 
180
- group = ThreadingGroup.from_connections(connection_list)
181
- group.run('echo "Hello"', hide=False)
201
+ print(f"All nodes finished broadcasting files")
182
202
 
183
203
 
184
204
  def download_files_from_blob(queue, blob_url, sas_token, folder, download_files, node_rank):
@@ -259,16 +279,14 @@ def broadcast_file_from_rank(rank, file_path, from_rank, device, file_size, max_
259
279
  raise ValueError(f"MD5 verification failed for {file_path}: {file_md5} != {src_md5_str}")
260
280
 
261
281
 
262
- def broadcast_folder_worker(args):
282
+ def broadcast_folder_worker(local_rank, node_rank, world_size, master_addr, master_port, args):
263
283
  assert args.tool in ["torch_nccl"], f"tool {args.tool} is not supported"
264
284
  if not _torch_is_available:
265
285
  raise ImportError("Torch is not available. Please install torch to use this feature.")
266
286
  start_time = time.time()
267
287
 
268
- init_process_group(backend='nccl')
269
- global_rank = int(os.environ['RANK'])
270
- local_rank = int(os.environ['LOCAL_RANK'])
271
- world_size = int(os.environ['WORLD_SIZE'])
288
+ global_rank = local_rank + node_rank * args.transfer_ranks_per_node
289
+ init_process_group(backend='nccl', init_method=f"tcp://{master_addr}:{master_port}", rank=global_rank, world_size=world_size)
272
290
  num_nodes = world_size // args.transfer_ranks_per_node
273
291
  worker_rank = local_rank
274
292
 
@@ -291,11 +309,15 @@ def broadcast_folder_worker(args):
291
309
  if global_rank == 0:
292
310
  print(f"Init {len(workers_groups)} worker groups")
293
311
 
294
- print(f"rank {global_rank} start broadcast worker, args = {args}, nccl init time: {time.time() - start_time:.2f}s")
295
-
312
+ if global_rank == 0:
313
+ print(f"rank {global_rank} start broadcast worker, args = {args}, nccl init time: {time.time() - start_time:.2f}s")
314
+ else:
315
+ print(f"rank {global_rank} nccl init time: {time.time() - start_time:.2f}s")
316
+
296
317
  file_size_dict = {}
297
318
 
298
319
  if global_rank == 0:
320
+ warnings.filterwarnings("ignore", category=UserWarning, message="The given buffer is not writable, and PyTorch does not support non-writable tensors.")
299
321
  # Parse include and exclude patterns
300
322
  include_patterns = [p.strip() for p in args.include_string.split(";") if p.strip()]
301
323
  exclude_patterns = [p.strip() for p in args.exclude_string.split(";") if p.strip()]
@@ -397,6 +419,23 @@ def broadcast_folder_worker(args):
397
419
  print(f"Rank {global_rank} finished broadcasting all files, time taken: {time.time() - start_time:.2f}s")
398
420
 
399
421
 
422
+ def broadcast_node_main(args):
423
+ if not _torch_is_available:
424
+ raise ImportError("Torch is not available. Please install torch to use this feature.")
425
+ parts = args.worker_args.split("_")
426
+ master_addr = parts[0]
427
+ master_port = int(parts[1])
428
+ num_nodes = int(parts[2])
429
+ num_ranks_per_node = int(parts[3])
430
+ node_rank = int(parts[4])
431
+ world_size = num_nodes * num_ranks_per_node
432
+
433
+ mp.spawn(
434
+ broadcast_folder_worker, nprocs=num_ranks_per_node, join=True,
435
+ args=(node_rank, world_size, master_addr, master_port, args),
436
+ )
437
+
438
+
400
439
  if __name__ == "__main__":
401
440
  import argparse
402
441
 
@@ -406,4 +445,4 @@ if __name__ == "__main__":
406
445
  if args.hostfile:
407
446
  broadcast_folder_main(args)
408
447
  else:
409
- broadcast_folder_worker(args)
448
+ broadcast_node_main(args)
addftool/ssh/__init__.py CHANGED
@@ -0,0 +1,128 @@
1
+ import re
2
+ import os
3
+ from pssh.clients import ParallelSSHClient
4
+ from pssh.config import HostConfig
5
+ import gevent
6
+
7
+
8
+ def get_host_config(hostname, configs):
9
+ """Get configuration for a specific host from all configs."""
10
+ host_config = {}
11
+
12
+ # Check for exact hostname match
13
+ if hostname in configs:
14
+ host_config.update(configs[hostname])
15
+
16
+ # Check for wildcard matches
17
+ for pattern, config in configs.items():
18
+ if '*' in pattern or '?' in pattern:
19
+ # Convert SSH glob pattern to regex pattern
20
+ regex_pattern = pattern.replace('.', '\\.').replace('*', '.*').replace('?', '.')
21
+ if re.match(f"^{regex_pattern}$", hostname):
22
+ host_config.update(config)
23
+
24
+ return host_config
25
+
26
+
27
+ def parse_ssh_config_file(file_path):
28
+ """Parse SSH config file into a dictionary of host configurations."""
29
+ host_configs = {}
30
+ current_host = None
31
+
32
+ if not os.path.exists(file_path):
33
+ return host_configs
34
+
35
+ try:
36
+ with open(file_path, 'r') as f:
37
+ for line in f:
38
+ line = line.strip()
39
+ if not line or line.startswith('#'):
40
+ continue
41
+
42
+ if line.lower().startswith('host ') and not line.lower().startswith('host *'):
43
+ hosts = line.split()[1:]
44
+ for host in hosts:
45
+ current_host = host
46
+ if current_host not in host_configs:
47
+ host_configs[current_host] = {}
48
+ elif current_host and ' ' in line:
49
+ key, value = line.split(None, 1)
50
+ host_configs[current_host][key.lower()] = value
51
+ except Exception as e:
52
+ print(f"Error reading {file_path}: {e}")
53
+
54
+ return host_configs
55
+
56
+
57
+ def get_ssh_config():
58
+ user_config_path = os.path.expanduser("~/.ssh/config")
59
+ system_config_path = "/etc/ssh/ssh_config"
60
+
61
+ user_configs = parse_ssh_config_file(user_config_path)
62
+ system_configs = parse_ssh_config_file(system_config_path)
63
+ return user_configs, system_configs
64
+
65
+
66
+ def get_client(hosts, user_ssh_configs=None, system_ssh_configs=None):
67
+ if user_ssh_configs is None or system_ssh_configs is None:
68
+ user_ssh_configs, system_ssh_configs = get_ssh_config()
69
+
70
+ to_connect = []
71
+ for hostname in hosts:
72
+ host_ssh_config = get_host_config(hostname, user_ssh_configs)
73
+ if not host_ssh_config:
74
+ host_ssh_config = get_host_config(hostname, system_ssh_configs)
75
+
76
+ if host_ssh_config:
77
+ host_config = HostConfig(
78
+ user=host_ssh_config.get('user', os.getenv('USER', 'root')),
79
+ port=int(host_ssh_config.get('port', 22)),
80
+ private_key=host_ssh_config.get('identityfile', None),
81
+ )
82
+ host_name = host_ssh_config.get('hostname', hostname)
83
+ to_connect.append((host_name, host_config))
84
+ else:
85
+ print(f"No config found for host {hostname}")
86
+
87
+ # Create a ParallelSSHClient with the list of hosts
88
+ client = ParallelSSHClient([host[0] for host in to_connect], host_config=[host[1] for host in to_connect])
89
+ return client
90
+
91
+
92
+ def handle_stream(host, stream_in, stream_name, print_call=None):
93
+ if print_call is None:
94
+ print_call = print
95
+ try:
96
+ if stream_in:
97
+ for line in stream_in:
98
+ prefix = " ERROR" if stream_name == "stderr" else ""
99
+ print_call(f"[{host}]{prefix}: {line}")
100
+ gevent.sleep(0) # 让出控制权
101
+ except Exception as e:
102
+ print(f"[{host}] {stream_name} Exception: {e}")
103
+
104
+
105
+ def _stdout_log(line):
106
+ """Log the output line."""
107
+ print(line)
108
+
109
+
110
+ def _stderr_log(line):
111
+ """Log the error line."""
112
+ print("ERROR: " + line)
113
+
114
+
115
+ def handle_hosts_outputs(hosts_outputs, out_log=None, err_log=None):
116
+ """Handle the outputs from the SSH command execution."""
117
+ if out_log is None:
118
+ out_log = _stdout_log
119
+ if err_log is None:
120
+ err_log = _stderr_log
121
+ jobs = []
122
+ for output in hosts_outputs:
123
+ host_name = output.host
124
+ if output:
125
+ jobs.append(gevent.spawn(handle_stream, host_name, output.stdout, "stdout", out_log))
126
+ jobs.append(gevent.spawn(handle_stream, host_name, output.stderr, "stderr", err_log))
127
+
128
+ gevent.joinall(jobs, raise_error=False)
@@ -1,8 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: addftool
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Requires-Dist: cryptography
5
5
  Requires-Dist: requests
6
6
  Requires-Dist: PyYAML
7
7
  Requires-Dist: psutil
8
8
  Requires-Dist: fabric
9
+ Requires-Dist: gevent
10
+ Requires-Dist: parallel-ssh
@@ -1,7 +1,7 @@
1
1
  addftool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  addftool/addf_portal.py,sha256=6XjwGs5m2mRVDWVvCPOiqn1NlxDcGBTQ9Kr_0g5RsJc,1130
3
- addftool/blob.py,sha256=YT3eZrC9mdYtntDlwEI5MUHFw98wKtBC_pHEdtqvsv4,9206
4
- addftool/broadcast_folder.py,sha256=X9tvvMT7cCDbfKqE7hUGaFbBOHIecIgD8xcT34Bqb_8,17708
3
+ addftool/blob.py,sha256=y1HZaDBUNeXicVytvwpRXwufvvrgxR33ruBlYpxnSa4,9453
4
+ addftool/broadcast_folder.py,sha256=GQBuSL8Ch537V_fSBHesWyqT3KRYry68pbYOKy2bDj4,19619
5
5
  addftool/sync.py,sha256=ZpYxbM8uiPFrV7ODmOaM7asVPCWaxBixA-arVc-1kfs,14045
6
6
  addftool/tool.py,sha256=FmxRY3-pP0_Z0zCUAngjmEMmPUruMftg_iUlB1t2TnQ,2001
7
7
  addftool/util.py,sha256=zlNLu8Be8cGIpNRqBw8_0q7nFxWlsJ9cToN62ohjdXE,2335
@@ -11,9 +11,9 @@ addftool/deploy/ssh_server.py,sha256=7glpJJNROskpqPkeYrTc2MbVzRendUZLv-ZgPs6HCq8
11
11
  addftool/deploy/vscode_server.py,sha256=tLtSvlcK2fEOaw6udWt8dNELVhwv9F59hF5DJJ-1Nak,2666
12
12
  addftool/process/__init__.py,sha256=Dze8OrcyjQlAbPrjE_h8bMi8W4b3OJyZOjTucPrkJvM,3721
13
13
  addftool/process/utils.py,sha256=JldxnwanLJOgxaPgmCJh7SeBRaaj5rFxWWxh1hpsvbA,2609
14
- addftool/ssh/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- addftool-0.2.5.dist-info/METADATA,sha256=e0fYxOCuq9cps5H7TZS-5wxSN3sqrs9e6g2acvJcMSw,170
16
- addftool-0.2.5.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
17
- addftool-0.2.5.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
18
- addftool-0.2.5.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
19
- addftool-0.2.5.dist-info/RECORD,,
14
+ addftool/ssh/__init__.py,sha256=h5_rCO0A6q2Yw9vFguQZZp_ApAJsT1dcnKnbKKZ0cDM,4409
15
+ addftool-0.2.7.dist-info/METADATA,sha256=7wElFYgZp3OX387bAsiQTwzfrc4pHH8dlK6vPnDULWU,220
16
+ addftool-0.2.7.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
17
+ addftool-0.2.7.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
18
+ addftool-0.2.7.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
19
+ addftool-0.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (79.0.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5