addftool 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
addftool/deploy/azure.py CHANGED
@@ -11,6 +11,7 @@ def deploy_azure(packages):
11
11
  command_prefix = "sudo " if need_sudo() else ""
12
12
  command = "dpkg -i /tmp/packages-microsoft-prod.deb"
13
13
  execute_command(command_prefix + command)
14
+ execute_command(command_prefix + "apt-get update")
14
15
 
15
16
  install_packages(packages)
16
17
 
addftool/sync.py ADDED
@@ -0,0 +1,341 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import subprocess
5
+ import hashlib
6
+ import tempfile
7
+ from multiprocessing import Queue, Process
8
+
9
+ try:
10
+ from fabric import Connection, ThreadingGroup
11
+ except ImportError:
12
+ Connection = object
13
+
14
+ try:
15
+ import torch
16
+ import torch.distributed as dist
17
+ from torch.distributed import init_process_group, destroy_process_group
18
+ _torch_is_available = True
19
+ except ImportError:
20
+ _torch_is_available = False
21
+
22
+
23
def add_sync_args(subparsers):
    """Register the 'sync' sub-command and its options on *subparsers*."""
    sync_parser = subparsers.add_parser('sync', help='download and sync folder from master node to other nodes')
    add_args(sync_parser)
26
+
27
+
28
def add_args(parser):
    """Attach every sync-related CLI option to *parser*.

    The same options serve two roles: the master invocation (with --hostfile)
    and the per-node worker invocation launched through torchrun.
    """
    add = parser.add_argument

    add("--from_blob_url", help="download from blob url to master node before sync", type=str, default="")
    add("--tool", help="tool name", type=str, default="torch_nccl", choices=["torch_nccl"])
    add("--hostfile", help="host file, sync file from node-0 to others", type=str, default="")

    add("--download_index_file", type=str, default="",
        help="the file to save the download index, should be generated by master node.")
    add("--md5_verify", action='store_true', default=False,
        help="whether to verify the md5 of the file after sync, default is False.")
    add("--port", help="the port for torchrun, default is 29501", type=int, default=29501)
    add("--torchrun_alias", type=str, default="torchrun",
        help="the alias of torchrun, default is torchrun. If you use torchrun, please set it to torchrun.")
    # distributed downloader from blob
    # NOTE(review): "donwload_nodes" is a typo but it is the published CLI
    # flag and is read as args.donwload_nodes elsewhere -- do not rename alone.
    add("--donwload_nodes", help="download nodes, default is node-0", type=int, default=1)
    add("folder", help="the folder need to sync", type=str)
43
+
44
+
45
class ConnectionWithCommand(Connection):
    """A fabric Connection that stages files and runs a preset command first.

    Before any user-supplied command runs, this connection creates the remote
    staging directory, uploads each (local, remote) pair from *puts*, and
    executes the stored *command* (the torchrun worker launch line).
    """

    def __init__(self, host, temp_config_dir, puts, command):
        super().__init__(host)
        self.temp_config_dir = temp_config_dir
        self.puts = puts
        self.command = command

    def run(self, command, **kwargs):
        # Ensure the remote staging dir exists, then push the staged files.
        super().run(f"mkdir -p {self.temp_config_dir}", **kwargs)
        for local_path, remote_path in self.puts:
            self.put(local_path, remote=remote_path)
        # Run the preset launch command, then the caller's command (if any).
        super().run(self.command, **kwargs)
        if command:
            super().run(command, **kwargs)
59
+
60
+
61
def get_ip_via_ssh(hostname):
    """Resolve *hostname*'s primary IP by running `hostname -I` over ssh.

    Returns the IP string on success, "127.0.0.1" for localhost, and None
    when ssh fails, times out (5 s), or raises.
    """
    if hostname == "localhost":
        return "127.0.0.1"
    remote_cmd = "hostname -I | awk '{print $1}'"
    try:
        result = subprocess.run(["ssh", hostname, remote_cmd],
                                capture_output=True, text=True, timeout=5)
    except Exception as e:
        print(f"Error executing SSH command on {hostname}: {e}")
        return None
    if result.returncode == 0:
        return result.stdout.strip()
    print(f"SSH {hostname} failed: {result.stderr}")
    return None
77
+
78
+
79
def sync_main(args):
    """Master entry point: plan a balanced distributed download and launch workers.

    Lists the blob container with azcopy, splits the non-empty files across
    ``args.donwload_nodes`` download nodes by size (largest first, round-robin),
    writes per-node index files plus a zero-size-file index, then starts one
    torchrun worker per host from ``args.hostfile`` over ssh/fabric.

    Raises:
        ValueError: when SAS_TOKEN is unset or the hostfile has too few hosts.
        RuntimeError: when the azcopy listing fails.
    """
    sas_token = os.environ.get("SAS_TOKEN")
    if not sas_token:
        raise ValueError("SAS_TOKEN environment variable is not set.")

    try:
        list_operation = subprocess.run(
            ["azcopy", "list", args.from_blob_url + sas_token, "--machine-readable"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except Exception as e:
        # Chain the cause so the azcopy invocation error is not lost.
        raise RuntimeError(f"Error while listing blob: {e}") from e
    if list_operation.returncode != 0:
        raise RuntimeError(f"Failed to list blob: {list_operation.stderr}")

    # Parse "name; Content Length:bytes" lines; anything else is informational.
    file_size_list = {}
    for line in list_operation.stdout.splitlines():
        parts = line.split("; Content Length:")
        if len(parts) != 2:
            print(f"INFO: {line}")
            continue
        file_name = parts[0].strip()
        file_size = int(parts[1])
        file_size_list[file_name] = file_size

    # Largest files first, dealt round-robin, roughly balances bytes per node.
    # Zero-byte files are set aside: broadcasting an empty tensor is pointless,
    # so workers create them locally instead.
    sorted_files = sorted(file_size_list.items(), key=lambda x: x[1], reverse=True)
    zero_files = []
    groups = [[] for _ in range(args.donwload_nodes)]
    for i, (file_name, file_size) in enumerate(sorted_files):
        if file_size == 0:
            zero_files.append(file_name)
            continue
        groups[i % args.donwload_nodes].append(file_name)

    # mkdtemp (not mktemp) creates the directory atomically, avoiding the
    # classic predictable-name race.
    local_temp_config_dir = tempfile.mkdtemp()
    print(f"Temp config dir: {local_temp_config_dir}")

    for i, group in enumerate(groups):
        group_file_path = os.path.join(local_temp_config_dir, f"node_{i}.txt")
        total_size = 0
        with open(group_file_path, "w") as f:
            for file_name in group:
                f.write(file_name + "\n")
                total_size += file_size_list[file_name]
        print(f"Node-{i} will download {len(group)} files, total size: {total_size} bytes")

    with open(os.path.join(local_temp_config_dir, "zero_files.txt"), "w") as f:
        for file_name in zero_files:
            f.write(file_name + "\n")

    print(f"Detect {len(zero_files)} files with size 0 bytes, they will be special handled.")

    with open(args.hostfile, "r") as f:
        host_list = []
        for line in f:
            line = line.strip()
            if line and not line.startswith("#"):
                host_list.append(line)

    if len(host_list) < len(groups):
        raise ValueError(f"Number of hosts in hostfile {len(host_list)} is less than number of download nodes {len(groups)}")

    print(f"Find {len(host_list)} hosts in hostfile: {args.hostfile}")
    connection_list = []

    # Remote staging path: only a unique *name* is needed (the directory is
    # created on the remote host), so mktemp is acceptable here; just make
    # sure it differs from the local config dir.
    remote_temp_config_dir = tempfile.mktemp()
    while remote_temp_config_dir == local_temp_config_dir:
        remote_temp_config_dir = tempfile.mktemp()

    master_addr = get_ip_via_ssh(host_list[0])
    for i, host in enumerate(host_list):
        # Stage this very script on every host; download nodes also get their
        # index file. Every node receives the zero-size-file index so its
        # worker can materialize those files locally.
        put_commands = []
        put_commands.append((__file__, os.path.join(remote_temp_config_dir, "sync.py")))
        if i < args.donwload_nodes:
            local_group_file = os.path.join(local_temp_config_dir, f"node_{i}.txt")
            put_commands.append((local_group_file, os.path.join(remote_temp_config_dir, f"node_{i}.txt")))
        put_commands.append((os.path.join(local_temp_config_dir, "zero_files.txt"), os.path.join(remote_temp_config_dir, "zero_files.txt")))

        commands = f"export SAS_TOKEN=\"{sas_token}\""
        commands += f" && {args.torchrun_alias} --nproc_per_node=1 --nnodes={len(host_list)} --node_rank={i} --master_addr={master_addr} --master_port={args.port}"
        commands += f" {remote_temp_config_dir}/sync.py {args.folder} --tool {args.tool} --from_blob_url {args.from_blob_url}"
        if args.md5_verify:
            commands += " --md5_verify"
        if i < args.donwload_nodes:
            commands += f" --download_index_file {remote_temp_config_dir}/node_{i}.txt"

        connection_list.append(ConnectionWithCommand(host, remote_temp_config_dir, put_commands, commands))

    # ConnectionWithCommand.run stages files and launches the workers; the
    # echo is just the trigger/user command.
    group = ThreadingGroup.from_connections(connection_list)
    group.run('echo "Hello"', hide=False)
178
+
179
+
180
def download_files_from_blob(queue, blob_url, sas_token, folder, download_files, node_rank):
    """Download *download_files* from *blob_url* into *folder* using azcopy.

    Runs in a child process; every successfully downloaded local path is put
    on *queue* so the sync loop can broadcast it. Each file gets up to three
    attempts; a file that fails all three is only logged (no queue entry).
    """
    if not blob_url.endswith("/"):
        blob_url += "/"
    print(f"Node-{node_rank} start downloading {len(download_files)} files from {blob_url} to {folder}")

    def _fetch_once(remote_name, local_path):
        # Single azcopy attempt; raises on a non-zero exit code.
        status = subprocess.run(
            ["azcopy", "copy", blob_url + remote_name + sas_token, local_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
        if status.returncode != 0:
            raise RuntimeError(f"Failed to download {remote_name}: {status.stderr}")

    for file_name in download_files:
        file_path = os.path.join(folder, file_name)
        target_dir = os.path.dirname(file_path)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)
        for _attempt in range(3):
            try:
                _fetch_once(file_name, file_path)
            except Exception as e:
                print(f"Rank {node_rank}: Download failed: {e}")
            else:
                print(f"Rank {node_rank}: Downloaded {file_name} successfully, from {blob_url} to {file_path}")
                queue.put(file_path)
                break
207
+
208
+
209
def sync_file_from_rank(rank, file_path, from_rank, md5_verify=False):
    """Broadcast the file at *file_path* from rank *from_rank* to all other ranks.

    Two collectives, in a fixed order every rank must follow: first the
    metadata (path, byte count, md5) via broadcast_object_list, then the raw
    bytes as a CUDA uint8 tensor via broadcast. Receivers write the bytes to
    the same path and optionally verify the MD5 digest.

    Raises:
        ValueError: on a receiver whose computed MD5 differs from the sender's.
    """
    if rank == from_rank:
        # Sender: load the whole file into memory.
        # NOTE(review): assumes the file fits in host and GPU memory -- confirm
        # for very large files.
        with open(file_path, "rb") as f:
            data = f.read()
        num_bytes = len(data)
        if md5_verify:
            md5 = hashlib.md5()
            md5.update(data)
            md5_value = md5.hexdigest()
        else:
            md5_value = ""
        # Metadata first so receivers know the target path and payload size.
        obj_list = [file_path, num_bytes, md5_value]
        dist.broadcast_object_list(obj_list, src=from_rank)
        tensor = torch.frombuffer(data, dtype=torch.uint8)
        tensor = tensor.cuda()
    else:
        # Receiver: the placeholder list is overwritten in place by the broadcast.
        obj_list = [0, "", ""]
        dist.broadcast_object_list(obj_list, src=from_rank)
        file_path, num_bytes, md5_value = obj_list
        tensor = torch.empty(num_bytes, dtype=torch.uint8, device='cuda')

    # Payload broadcast: all ranks reach this collective with matching sizes.
    dist.broadcast(tensor, src=from_rank)
    if rank != from_rank:
        file_dir = os.path.dirname(file_path)
        if not os.path.exists(file_dir):
            os.makedirs(file_dir, exist_ok=True)
        with open(file_path, "wb") as f:
            tensor.cpu().numpy().tofile(f)
        if md5_verify:
            md5 = hashlib.md5()
            md5.update(tensor.cpu().numpy())
            md5_value_recv = md5.hexdigest()
            if md5_value_recv != md5_value:
                raise ValueError(f"MD5 mismatch for file {file_path}: {md5_value_recv} != {md5_value}")
            else:
                print(f"Node-{rank} verified file {file_path} with MD5: {md5_value_recv}")
245
+
246
+
247
def sync_worker(args):
    """Per-node worker launched by torchrun: download (if assigned) and sync.

    A background process downloads this node's share of files; the main loop
    runs an all_reduce(MIN) over per-node status codes each round to agree on
    who broadcasts next:

      - own rank          -> this node has a downloaded file ready to send
      - world_size        -> still waiting on downloads
      - world_size + 1    -> this node has transferred everything it owns

    Finally, zero-byte files listed in zero_files.txt (staged next to this
    script) are created locally on every node.

    Raises:
        ImportError: when torch is not installed.
        ValueError: when the world size is below 2.
    """
    assert args.tool in ["torch_nccl"], f"tool {args.tool} is not supported"
    if not _torch_is_available:
        raise ImportError("Torch is not available. Please install torch to use this feature.")
    start_time = time.time()

    init_process_group(backend='nccl')
    node_rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    print(f"rank {node_rank} start sync worker, args = {args}, nccl init time: {time.time() - start_time:.2f}s")

    if world_size < 2:
        raise ValueError("World size must be at least 2 for distributed download.")

    download_queue = Queue()

    download_files = []
    transfered_files = set()
    if args.download_index_file:
        with open(args.download_index_file, "r") as f:
            for line in f:
                download_files.append(line.strip())

    # Downloads overlap with broadcasting: azcopy runs in a child process and
    # hands completed paths back through the queue.
    download_process = Process(
        target=download_files_from_blob,
        args=(download_queue, args.from_blob_url, os.environ["SAS_TOKEN"], args.folder, download_files, node_rank),
    )
    download_process.start()

    last_download = None

    while True:
        if len(download_files) == len(transfered_files):
            status_code = world_size + 1
        elif last_download is not None:
            status_code = node_rank
        else:
            try:
                # Short timeout keeps the collective cadence alive even when
                # no download has finished yet.
                last_download = download_queue.get(timeout=1)
                status_code = node_rank
            except Exception:
                status_code = world_size

        # MIN-reduce picks the lowest-ranked node with a file ready; only if
        # nobody is ready does "waiting" (or "all done") win.
        global_status_code = torch.tensor(status_code).cuda()
        dist.all_reduce(global_status_code, op=dist.ReduceOp.MIN)
        global_status_code = global_status_code.item()

        if global_status_code == world_size + 1:
            print(f"Node-{node_rank} finished downloading all files, time taken: {time.time() - start_time:.2f}s")
            break
        elif global_status_code == world_size:
            if node_rank == 0:
                print(f"All nodes is waiting for other nodes to finish downloading...")
            time.sleep(1)
        elif global_status_code == node_rank:
            print(f"Node-{node_rank} is downloaded {last_download}, prepare to broadcast it..., time taken: {time.time() - start_time:.2f}s")
            sync_file_from_rank(node_rank, last_download, node_rank, md5_verify=args.md5_verify)
            transfered_files.add(last_download)
            last_download = None
        else:
            # Another node is broadcasting; participate as a receiver.
            sync_file_from_rank(node_rank, "", global_status_code, md5_verify=args.md5_verify)

    print(f"Node-{node_rank} finished downloading files, time taken: {time.time() - start_time:.2f}s")
    dist.barrier()
    download_process.join()
    destroy_process_group()

    # FIX: the zero-size-file index is staged in the same directory as this
    # script; the original joined onto __file__ itself, producing a path that
    # can never exist, so zero-byte files were never created.
    zero_file = os.path.join(os.path.dirname(__file__), "zero_files.txt")
    if os.path.exists(zero_file):
        with open(zero_file, "r") as f:
            zero_files = [line.strip() for line in f]
        for zero_file_name in zero_files:
            zero_file_path = os.path.join(args.folder, zero_file_name)
            zero_file_dir = os.path.dirname(zero_file_path)
            if not os.path.exists(zero_file_dir):
                os.makedirs(zero_file_dir, exist_ok=True)
            with open(zero_file_path, "wb") as f:
                f.write(b"")
        print(f"Node-{node_rank} handled {len(zero_files)} files with size 0 bytes, time taken: {time.time() - start_time:.2f}s")

    print(f"Node-{node_rank} finished syncing all files, time taken: {time.time() - start_time:.2f}s")
330
+
331
+
332
if __name__ == "__main__":
    # Standalone entry: with --hostfile we act as the orchestrating master;
    # otherwise we are a torchrun-launched worker on this node.
    import argparse

    cli = argparse.ArgumentParser(description="Addf's tool")
    add_args(cli)
    parsed = cli.parse_args()
    if parsed.hostfile:
        sync_main(parsed)
    else:
        sync_worker(parsed)
@@ -1,7 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: addftool
3
- Version: 0.1.6
3
+ Version: 0.1.8
4
4
  Requires-Dist: cryptography
5
5
  Requires-Dist: requests
6
6
  Requires-Dist: PyYAML
7
7
  Requires-Dist: psutil
8
+ Requires-Dist: fabric
@@ -1,17 +1,17 @@
1
1
  addftool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  addftool/addf_portal.py,sha256=w2LgsoutfnrKhtrQAXouUMwLqnsp5ALlsBYUWg8n9NM,781
3
3
  addftool/blob.py,sha256=NZOItDyFUIdV1tfhJZJJBEzGy296CE5NCictTzP4OPc,8282
4
+ addftool/sync.py,sha256=ZpYxbM8uiPFrV7ODmOaM7asVPCWaxBixA-arVc-1kfs,14045
4
5
  addftool/tool.py,sha256=EuKQ2t2InN7yB-_oYLcdsA7vRqzRGTunwIxplUSqEG0,2054
5
6
  addftool/util.py,sha256=zlNLu8Be8cGIpNRqBw8_0q7nFxWlsJ9cToN62ohjdXE,2335
6
7
  addftool/deploy/__init__.py,sha256=tpyoTh3SqAQojPizsJDvQohu1Pcb3-w-DP5sO4-5lBM,1220
7
- addftool/deploy/azure.py,sha256=UQR1hOEYUtsm2fbWBczsnEB_mh7yUuN2NDv3sgMMsac,1246
8
+ addftool/deploy/azure.py,sha256=_o_9Eh8cVwLDAqvfyRYBtQRHs_Gul-nCs2ZXttwO1bk,1301
8
9
  addftool/deploy/ssh_server.py,sha256=f2T8fgwACVljPfdcimMywUjsFnLCWRde7iWPAILpRz8,5463
9
10
  addftool/process/__init__.py,sha256=gPdGsjMEET6crzOz4Iw5cmf6RR1toXGovydRXv8Uagk,3543
10
11
  addftool/process/utils.py,sha256=me4HqMz5OgRcQMUJmVhKdTJh4SW5BB-pd_lq7g8-UwE,2252
11
12
  addftool/ssh/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- addftool/sync/__init__.py,sha256=wOqFCOA51rFUttBjOO44W3Fc66mhX5ir2R89lsO6gR0,1702
13
- addftool-0.1.6.dist-info/METADATA,sha256=SQSnBFDxmD836XQAgZDtqvpzT_Z1LewNKTRCfdMC5ag,148
14
- addftool-0.1.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
15
- addftool-0.1.6.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
16
- addftool-0.1.6.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
17
- addftool-0.1.6.dist-info/RECORD,,
13
+ addftool-0.1.8.dist-info/METADATA,sha256=f50lOq51j55hNh2hnk6SdAni0E7MXHec81sBCOHZ_ro,170
14
+ addftool-0.1.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
15
+ addftool-0.1.8.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
16
+ addftool-0.1.8.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
17
+ addftool-0.1.8.dist-info/RECORD,,
addftool/sync/__init__.py DELETED
@@ -1,42 +0,0 @@
1
- import os
2
-
3
-
4
- def add_sync_args(subparsers):
5
- process_killer_parser = subparsers.add_parser('sync', help='download and sync folder from master node to other nodes')
6
-
7
- process_killer_parser.add_argument("--from_blob_url", help="download from blob url to master node before sync", type=str, default="")
8
- process_killer_parser.add_argument("--sas_token", help="sas token for blob url", type=str, default="")
9
- process_killer_parser.add_argument("--tool", help="tool name", type=str, default="torch_nccl", choices=["torch_nccl", "rsync"])
10
- process_killer_parser.add_argument("--hostfile", help="host file, sync file from node-0 to others", type=str, default="")
11
-
12
- # distributed downloader from blob
13
- process_killer_parser.add_argument("--donwload_nodes", help="download nodes, default is node-0", type=int, default=1)
14
-
15
- process_killer_parser.add_argument("folder", nargs='?', help="the folder need to sync", type=str, default="")
16
-
17
-
18
- def sync_main(args):
19
- print(args)
20
- exit(0)
21
- if args.source == "" or args.target == "":
22
- print("Please provide source and target folder")
23
- return
24
-
25
- # check if source is a folder
26
- if not os.path.isdir(args.source):
27
- print(f"Source {args.source} is not a folder")
28
- return
29
-
30
- # check if target is a folder
31
- if not os.path.isdir(args.target):
32
- print(f"Target {args.target} is not a folder")
33
- return
34
-
35
- # check if source and target are the same
36
- if os.path.abspath(args.source) == os.path.abspath(args.target):
37
- print(f"Source and target are the same")
38
- return
39
-
40
- # sync source to target
41
- command = f"rsync -avz --delete {args.source} {args.target}"
42
- os.system(command)