addftool 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
addftool/broadcast_folder.py

@@ -1,7 +1,9 @@
  import os
  import time
+ import fnmatch
  import subprocess
  import hashlib
+ from concurrent.futures import ThreadPoolExecutor

  from fabric import Connection, ThreadingGroup

@@ -33,6 +35,22 @@ def add_args(parser):
      parser.add_argument("--transfer_ranks_per_node", type=int, default=8,
          help="the number of ranks per node to transfer the files, default is 8.")

+     parser.add_argument("--contain_md5_files", action='store_true', default=False,
+         help="whether to contain the md5 files in the folder, default is False. " \
+              "If True, the md5 files will be transferred to the other nodes and verified. " \
+              "If False, the md5 files will be ignored.")
+
+     parser.add_argument("--include-string", type=str, default="",
+         help="the string to include the files, default is empty. " \
+              "Such as *.py, *.yaml, *.json, \"*.pt;*.pth\" etc. " \
+              "Only node-0 will include the files from the folder, " \
+              "If empty, will transfer all the files from the node-0's local folder.")
+     parser.add_argument("--exclude-string", type=str, default="",
+         help="the string to exclude the files, default is empty. " \
+              "Such as *.py, *.yaml, *.json, \"*.pt;*.pth\" etc. " \
+              "Only node-0 will exclude the files from the folder, " \
+              "If empty, will transfer all the files from the node-0's local folder.")
+
      parser.add_argument("--from_blob_url", type=str, default="",
          help="the blob url to download from, default is empty. " \
              "Only node-0 will download the files from the blob url, " \
@@ -77,6 +95,52 @@ def get_ip_via_ssh(hostname):
      return None


+ def parallel_check_md5(file_list, expected_md5s):
+     """
+     Check MD5 checksums for the given files in parallel.
+
+     Args:
+         file_list: List of file paths to check
+         expected_md5s: Dict mapping each file path to its expected MD5 hash
+
+     Returns:
+         True if all MD5 checksums match, False otherwise
+     """
+     def calculate_md5(file_path):
+         # Call md5sum and capture output
+         result = subprocess.run(["md5sum", file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+         if result.returncode != 0:
+             print(f"Failed to calculate MD5 for {file_path}: {result.stderr}")
+             return file_path, None
+
+         # md5sum output format: "<md5_hash>  <file_path>"
+         md5_hash = result.stdout.strip().split()[0]
+         return file_path, md5_hash
+
+     # Calculate MD5 checksums in parallel
+     with ThreadPoolExecutor(max_workers=8) as executor:
+         results = list(executor.map(calculate_md5, file_list))
+
+     # Check if all MD5s match
+     all_match = True
+     for file_path, actual_md5 in results:
+         if actual_md5 is None:
+             all_match = False
+             continue
+
+         if file_path not in expected_md5s:
+             print(f"No expected MD5 for {file_path}")
+             all_match = False
+             continue
+
+         expected_md5 = expected_md5s[file_path]
+         if actual_md5 != expected_md5:
+             print(f"MD5 mismatch for {file_path}: expected {expected_md5}, got {actual_md5}")
+             all_match = False
+
+     return all_match
+
+
  def broadcast_folder_main(args):
      with open(args.hostfile, "r") as f:
          host_list = []
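Because parallel_check_md5 shells out to the md5sum binary, it is Linux-specific. A minimal standard-library sketch of the same check (the helper names here are mine, not addftool's):

    # Portable equivalent using hashlib instead of the md5sum binary (sketch only).
    import hashlib
    from concurrent.futures import ThreadPoolExecutor

    def _md5_of(path, chunk_size=1 << 20):
        # Stream the file in chunks so large checkpoints don't load into memory.
        h = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return path, h.hexdigest()

    def check_md5s(paths, expected):
        with ThreadPoolExecutor(max_workers=8) as pool:
            return all(expected.get(path) == digest for path, digest in pool.map(_md5_of, paths))

Threads are a reasonable fit here: both file reads and hashlib's update on large buffers release the GIL, so the pool actually overlaps work.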
@@ -95,11 +159,17 @@ def broadcast_folder_main(args):
      for i, host in enumerate(host_list):
          put_commands = []
          put_commands.append((__file__, os.path.join(remote_temp_config_dir, "broadcast.py")))
-         commnads = "NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 "
+         commnads = "NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
          if os.environ.get("SAS_TOKEN") is not None and i == 0:
              commnads += f"SAS_TOKEN=\"{os.environ['SAS_TOKEN']}\" "
          commnads += f"{args.torchrun_alias} --nproc_per_node={args.transfer_ranks_per_node} --nnodes={len(host_list)} --node_rank={i} --master_addr={master_addr} --master_port={args.port}"
          commnads += f" {remote_temp_config_dir}/broadcast.py {args.folder} --tool {args.tool} --transfer_ranks_per_node {args.transfer_ranks_per_node} "
+         if args.contain_md5_files:
+             commnads += " --contain_md5_files"
+         if args.include_string:
+             commnads += f" --include-string \"{args.include_string}\""
+         if args.exclude_string:
+             commnads += f" --exclude-string \"{args.exclude_string}\""
          if args.from_blob_url and i == 0:
              commnads += f" --from_blob_url {args.from_blob_url}"
          if args.md5_verify:
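With hypothetical values filled in, the remote command assembled above for node 0 would look roughly like the string below (illustration only; the addresses, paths, and tool name are made up):

    # Example of the kind of command line the loop composes for node_rank=0.
    example = (
        "NCCL_IB_DISABLE=0 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 "
        "torchrun --nproc_per_node=8 --nnodes=4 --node_rank=0 "
        "--master_addr=10.0.0.1 --master_port=29500 "
        "/tmp/config/broadcast.py /data/checkpoints --tool nccl --transfer_ranks_per_node 8 "
        "--contain_md5_files --include-string \"*.pt;*.pth\""
    )
    print(example)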
@@ -226,20 +296,53 @@ def broadcast_folder_worker(args):
      file_size_dict = {}

      if global_rank == 0:
-         if args.from_blob_url:
-             raise NotImplementedError("Downloading files from blob storage is not implemented yet.")
-
+         # Parse include and exclude patterns
+         include_patterns = [p.strip() for p in args.include_string.split(";") if p.strip()]
+         exclude_patterns = [p.strip() for p in args.exclude_string.split(";") if p.strip()]
+
+         print(f"Include patterns: {include_patterns}")
+         print(f"Exclude patterns: {exclude_patterns}")
+
          if not os.path.exists(args.folder):
              raise ValueError(f"Folder {args.folder} does not exist.")
+
+         # Gather and filter files in a single pass
+         file_size_dict = {}
          for root, dirs, files in os.walk(args.folder):
              for file in files:
                  file_path = os.path.join(root, file)
-                 file_size = os.path.getsize(file_path)
-                 file_size_dict[file_path] = file_size
+                 file_name = os.path.basename(file_path)
+
+                 # Skip md5 files if not containing them
+                 if file_name.endswith(".md5") and not args.contain_md5_files:
+                     continue
+
+                 # Apply include filters first (if any)
+                 included = not include_patterns  # Include by default if no include patterns
+                 if include_patterns:
+                     for pattern in include_patterns:
+                         if fnmatch.fnmatch(file_name, pattern):
+                             included = True
+                             break
+
+                 # Then apply exclude filters
+                 if included and exclude_patterns:
+                     for pattern in exclude_patterns:
+                         if fnmatch.fnmatch(file_name, pattern):
+                             included = False
+                             break
+
+                 # Add to file dict if passes both filters
+                 if included:
+                     file_size_dict[file_path] = os.path.getsize(file_path)
+
+         print(f"After filtering: {len(file_size_dict)} files selected for transfer")
+         if len(include_patterns) > 0 or len(exclude_patterns) > 0:
+             print(f"Files selected for transfer: {file_size_dict.keys()}")

          # sort the file list by size
          file_list = sorted(file_size_dict.keys(), key=lambda x: file_size_dict[x], reverse=True)
-         file_size_list = [file_size_dict[file] for file in file_list]
+         file_size_list = [file_size_dict[file] for file in file_list]
          obj_list = [file_list, file_size_list]
          dist.broadcast_object_list(obj_list, src=0)
      else:
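The filtering pass above reduces to a small pure function: include by default unless include patterns are given, then let exclude patterns veto. A sketch (should_include is my name for it, not the package's):

    # Standalone restatement of the include/exclude decision.
    import fnmatch

    def should_include(file_name, include_patterns, exclude_patterns):
        included = (not include_patterns or
                    any(fnmatch.fnmatch(file_name, p) for p in include_patterns))
        if included and any(fnmatch.fnmatch(file_name, p) for p in exclude_patterns):
            return False
        return included

    assert should_include("model.pt", ["*.pt", "*.pth"], []) is True
    assert should_include("model.pt", ["*.pt"], ["model.*"]) is False
    assert should_include("notes.txt", [], []) is True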
@@ -251,6 +354,7 @@ def broadcast_folder_worker(args):

      worker_g = workers_groups[worker_rank]
      from_rank = global_rank % args.transfer_ranks_per_node
+     broadcast_file_list = []
      for i in range(len(file_list)):
          if i % args.transfer_ranks_per_node == worker_rank:
              file_path = file_list[i]
@@ -261,6 +365,7 @@ def broadcast_folder_worker(args):
              )
              if global_rank == from_rank:
                  print(f"Group {global_rank} finished broadcasting {file_path}, size: {file_size / (1024 * 1024):.2f} MB, time taken: {time.time() - start_time:.2f}s")
+                 broadcast_file_list.append(file_path)

      dist.barrier()
      for i in range(len(workers_groups)):
@@ -268,6 +373,27 @@ def broadcast_folder_worker(args):
          dist.destroy_process_group(workers_groups[i])
      destroy_process_group()

+     if args.contain_md5_files and global_rank % args.transfer_ranks_per_node == 0:
+         to_verify_files = []
+         excepted_md5s = {}
+         for file_path in file_list:
+             if not file_path.endswith(".md5"):
+                 md5_file_path = file_path + ".md5"
+                 if os.path.exists(md5_file_path):
+                     with open(md5_file_path, "r") as f:
+                         md5_hash = f.read().strip()
+                     excepted_md5s[file_path] = md5_hash
+                     to_verify_files.append(file_path)
+                 else:
+                     print(f"MD5 file {md5_file_path} not found, skipping verification.")
+
+         # Verify MD5 checksums
+         if not parallel_check_md5(to_verify_files, excepted_md5s):
+             print(f"MD5 verification failed for some files, please check the logs.")
+             raise ValueError("MD5 verification failed.")
+         else:
+             print(f"Rank-{global_rank}: MD5 verification passed for all files.")
+
      print(f"Rank {global_rank} finished broadcasting all files, time taken: {time.time() - start_time:.2f}s")


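This verification step expects each transferred file to ship with a <name>.md5 sidecar containing only the hex digest, since the worker reads the whole sidecar as the expected hash. A sketch of producing such sidecars (write_md5_sidecar is a hypothetical helper, not part of addftool):

    # Write the <file>.md5 sidecars that --contain_md5_files verification reads.
    # Note: the sidecar must hold only the hex digest; "md5sum file > file.md5"
    # would also embed the file path and would not compare equal here.
    import hashlib

    def write_md5_sidecar(path):
        h = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        with open(path + ".md5", "w") as sidecar:
            sidecar.write(h.hexdigest() + "\n")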
addftool-0.2.4.dist-info/METADATA → addftool-0.2.5.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: addftool
- Version: 0.2.4
+ Version: 0.2.5
  Requires-Dist: cryptography
  Requires-Dist: requests
  Requires-Dist: PyYAML
addftool-0.2.4.dist-info/RECORD → addftool-0.2.5.dist-info/RECORD

@@ -1,7 +1,7 @@
  addftool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  addftool/addf_portal.py,sha256=6XjwGs5m2mRVDWVvCPOiqn1NlxDcGBTQ9Kr_0g5RsJc,1130
  addftool/blob.py,sha256=YT3eZrC9mdYtntDlwEI5MUHFw98wKtBC_pHEdtqvsv4,9206
- addftool/broadcast_folder.py,sha256=bEOr-8Q14DhZFS658hab4U-9HvAW3EGluEriilIdEXQ,11976
+ addftool/broadcast_folder.py,sha256=X9tvvMT7cCDbfKqE7hUGaFbBOHIecIgD8xcT34Bqb_8,17708
  addftool/sync.py,sha256=ZpYxbM8uiPFrV7ODmOaM7asVPCWaxBixA-arVc-1kfs,14045
  addftool/tool.py,sha256=FmxRY3-pP0_Z0zCUAngjmEMmPUruMftg_iUlB1t2TnQ,2001
  addftool/util.py,sha256=zlNLu8Be8cGIpNRqBw8_0q7nFxWlsJ9cToN62ohjdXE,2335
@@ -12,8 +12,8 @@ addftool/deploy/vscode_server.py,sha256=tLtSvlcK2fEOaw6udWt8dNELVhwv9F59hF5DJJ-1
  addftool/process/__init__.py,sha256=Dze8OrcyjQlAbPrjE_h8bMi8W4b3OJyZOjTucPrkJvM,3721
  addftool/process/utils.py,sha256=JldxnwanLJOgxaPgmCJh7SeBRaaj5rFxWWxh1hpsvbA,2609
  addftool/ssh/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- addftool-0.2.4.dist-info/METADATA,sha256=Mxypih4dkYUmHFyyVBqbZLVf4CYJqapkfHMJ2zCMVh0,170
- addftool-0.2.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- addftool-0.2.4.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
- addftool-0.2.4.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
- addftool-0.2.4.dist-info/RECORD,,
+ addftool-0.2.5.dist-info/METADATA,sha256=e0fYxOCuq9cps5H7TZS-5wxSN3sqrs9e6g2acvJcMSw,170
+ addftool-0.2.5.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+ addftool-0.2.5.dist-info/entry_points.txt,sha256=9lkmuWMInwUAtev8w8poNkNd7iML9Bjd5CBCFVxg2b8,111
+ addftool-0.2.5.dist-info/top_level.txt,sha256=jqj56-plrBbyzY0tIxB6wPzjAA8kte4hUlajyyQygN4,9
+ addftool-0.2.5.dist-info/RECORD,,
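Each RECORD row has the form path,sha256=<digest>,size, where the digest is urlsafe base64 of the file's SHA-256 with the trailing padding stripped (the wheel/PEP 427 convention). A sketch of recomputing one row's hash field:

    # Recompute a RECORD-style hash entry for a file on disk (sketch).
    import base64
    import hashlib

    def record_hash(path):
        with open(path, "rb") as f:
            digest = hashlib.sha256(f.read()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")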
addftool-0.2.4.dist-info/WHEEL → addftool-0.2.5.dist-info/WHEEL

@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (79.0.0)
  Root-Is-Purelib: true
  Tag: py3-none-any