genarena 0.0.1-py3-none-any.whl → 0.1.1-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (47)
  1. genarena/__init__.py +49 -2
  2. genarena/__main__.py +10 -0
  3. genarena/arena.py +1685 -0
  4. genarena/battle.py +337 -0
  5. genarena/bt_elo.py +507 -0
  6. genarena/cli.py +1581 -0
  7. genarena/data.py +476 -0
  8. genarena/deploy/Dockerfile +22 -0
  9. genarena/deploy/README.md +55 -0
  10. genarena/deploy/__init__.py +5 -0
  11. genarena/deploy/app.py +84 -0
  12. genarena/experiments.py +121 -0
  13. genarena/leaderboard.py +270 -0
  14. genarena/logs.py +409 -0
  15. genarena/models.py +412 -0
  16. genarena/prompts/__init__.py +127 -0
  17. genarena/prompts/mmrb2.py +373 -0
  18. genarena/sampling.py +336 -0
  19. genarena/state.py +656 -0
  20. genarena/sync/__init__.py +105 -0
  21. genarena/sync/auto_commit.py +118 -0
  22. genarena/sync/deploy_ops.py +543 -0
  23. genarena/sync/git_ops.py +422 -0
  24. genarena/sync/hf_ops.py +891 -0
  25. genarena/sync/init_ops.py +431 -0
  26. genarena/sync/packer.py +587 -0
  27. genarena/sync/submit.py +837 -0
  28. genarena/utils.py +103 -0
  29. genarena/validation/__init__.py +19 -0
  30. genarena/validation/schema.py +327 -0
  31. genarena/validation/validator.py +329 -0
  32. genarena/visualize/README.md +148 -0
  33. genarena/visualize/__init__.py +14 -0
  34. genarena/visualize/app.py +938 -0
  35. genarena/visualize/data_loader.py +2430 -0
  36. genarena/visualize/static/app.js +3762 -0
  37. genarena/visualize/static/model_aliases.json +86 -0
  38. genarena/visualize/static/style.css +4104 -0
  39. genarena/visualize/templates/index.html +413 -0
  40. genarena/vlm.py +519 -0
  41. genarena-0.1.1.dist-info/METADATA +178 -0
  42. genarena-0.1.1.dist-info/RECORD +44 -0
  43. {genarena-0.0.1.dist-info → genarena-0.1.1.dist-info}/WHEEL +1 -2
  44. genarena-0.1.1.dist-info/entry_points.txt +2 -0
  45. genarena-0.0.1.dist-info/METADATA +0 -26
  46. genarena-0.0.1.dist-info/RECORD +0 -5
  47. genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/sync/hf_ops.py
@@ -0,0 +1,891 @@
+ # Copyright 2026 Ruihang Li.
+ # Licensed under the Apache License, Version 2.0.
+ # See LICENSE file in the project root for details.
+
+ """
+ Huggingface operations module for GenArena.
+
+ This module provides functionality for uploading and downloading
+ arena data to/from Huggingface Dataset repositories.
+ """
+
+ import logging
+ import os
+ import time
+ import functools
+ from typing import Any, Callable, Optional, TypeVar
+
+ logger = logging.getLogger(__name__)
+
+ # Type variable for retry decorator
+ T = TypeVar("T")
+
+ # Default retry configuration
+ DEFAULT_MAX_RETRIES = 3
+ DEFAULT_RETRY_DELAY = 2.0
+ DEFAULT_RETRY_BACKOFF = 2.0  # Exponential backoff multiplier
+
+
+ def retry_on_failure(
+     max_retries: int = DEFAULT_MAX_RETRIES,
+     delay: float = DEFAULT_RETRY_DELAY,
+     backoff: float = DEFAULT_RETRY_BACKOFF,
+     exceptions: tuple = (Exception,),
+ ) -> Callable:
+     """
+     Decorator that retries a function on failure with exponential backoff.
+
+     Args:
+         max_retries: Maximum number of retry attempts
+         delay: Initial delay between retries in seconds
+         backoff: Multiplier for delay after each retry
+         exceptions: Tuple of exception types to catch and retry
+
+     Returns:
+         Decorated function
+     """
+     def decorator(func: Callable[..., T]) -> Callable[..., T]:
+         @functools.wraps(func)
+         def wrapper(*args: Any, **kwargs: Any) -> T:
+             current_delay = delay
+             last_exception = None
+
+             for attempt in range(max_retries + 1):
+                 try:
+                     return func(*args, **kwargs)
+                 except exceptions as e:
+                     last_exception = e
+                     if attempt < max_retries:
+                         logger.warning(
+                             f"{func.__name__} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
+                             f"Retrying in {current_delay:.1f}s..."
+                         )
+                         time.sleep(current_delay)
+                         current_delay *= backoff
+                     else:
+                         logger.error(
+                             f"{func.__name__} failed after {max_retries + 1} attempts: {e}"
+                         )
+
+             # Re-raise the last exception
+             raise last_exception  # type: ignore
+
+         return wrapper
+     return decorator
+
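Usage sketch of the decorator above (editorial illustration, not part of the package); flaky_fetch and its failure rate are hypothetical:

    import random

    @retry_on_failure(max_retries=3, delay=0.5, exceptions=(ConnectionError,))
    def flaky_fetch() -> str:
        # Simulated transient failure to exercise the retry path.
        if random.random() < 0.5:
            raise ConnectionError("simulated network error")
        return "ok"

    flaky_fetch()  # up to 3 retries with 0.5s/1.0s/2.0s waits, then the last exception is re-raised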
+ # Environment variable for HF token
+ HF_TOKEN_ENV = "HF_TOKEN"
+
+
+ def get_hf_token() -> Optional[str]:
+     """
+     Get the Huggingface token from environment variable.
+
+     Returns:
+         Token string or None if not set
+     """
+     return os.environ.get(HF_TOKEN_ENV)
+
+
+ def require_hf_token() -> str:
+     """
+     Get the Huggingface token, raising an error if not set.
+
+     Returns:
+         Token string
+
+     Raises:
+         ValueError: If HF_TOKEN environment variable is not set
+     """
+     token = get_hf_token()
+     if not token:
+         raise ValueError(
+             f"Environment variable {HF_TOKEN_ENV} is not set. "
+             f"Please set it with your Huggingface token: "
+             f"export {HF_TOKEN_ENV}='your_token_here'"
+         )
+     return token
+
+
+ def validate_dataset_repo(repo_id: str, token: Optional[str] = None) -> tuple[bool, str]:
+     """
+     Validate that a repository exists and is a Dataset type.
+
+     Args:
+         repo_id: Repository ID (e.g., "username/repo-name")
+         token: Huggingface token (optional for public repos)
+
+     Returns:
+         Tuple of (is_valid, message)
+     """
+     try:
+         from huggingface_hub import HfApi
+         from huggingface_hub.utils import RepositoryNotFoundError
+
+         api = HfApi(token=token)
+
+         try:
+             repo_info = api.repo_info(repo_id=repo_id, repo_type="dataset")
+             return True, f"Valid Dataset repository: {repo_id}"
+         except RepositoryNotFoundError:
+             # Try to check if it exists as a different type
+             try:
+                 # Check if it's a model repo
+                 api.repo_info(repo_id=repo_id, repo_type="model")
+                 return False, (
+                     f"Repository '{repo_id}' exists but is a Model repository, not a Dataset. "
+                     f"Please create a Dataset repository on Huggingface."
+                 )
+             except RepositoryNotFoundError:
+                 pass
+
+             try:
+                 # Check if it's a space repo
+                 api.repo_info(repo_id=repo_id, repo_type="space")
+                 return False, (
+                     f"Repository '{repo_id}' exists but is a Space repository, not a Dataset. "
+                     f"Please create a Dataset repository on Huggingface."
+                 )
+             except RepositoryNotFoundError:
+                 pass
+
+             return False, (
+                 f"Repository '{repo_id}' does not exist. "
+                 f"Please create a Dataset repository on Huggingface first: "
+                 f"https://huggingface.co/new-dataset"
+             )
+
+     except ImportError:
+         return False, (
+             "huggingface_hub package is not installed. "
+             "Please install it with: pip install huggingface_hub"
+         )
+     except Exception as e:
+         return False, f"Error validating repository: {e}"
+
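A sketch of how a caller might gate writes on this check (editorial illustration; the repo id is a placeholder):

    ok, message = validate_dataset_repo("username/my-arena-data", token=get_hf_token())
    if not ok:
        raise SystemExit(message)  # repo missing, or a Model/Space repo instead of a Dataset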
+
+ def list_repo_files(
+     repo_id: str,
+     token: Optional[str] = None,
+     revision: str = "main",
+ ) -> tuple[bool, list[str], str]:
+     """
+     List all files in a Huggingface Dataset repository.
+
+     Args:
+         repo_id: Repository ID
+         token: Huggingface token (optional for public repos)
+         revision: Branch/revision name
+
+     Returns:
+         Tuple of (success, file_list, message)
+     """
+     try:
+         from huggingface_hub import HfApi
+
+         api = HfApi(token=token)
+
+         files = api.list_repo_files(
+             repo_id=repo_id,
+             repo_type="dataset",
+             revision=revision,
+         )
+
+         return True, list(files), f"Found {len(files)} files in {repo_id}"
+
+     except Exception as e:
+         return False, [], f"Error listing repository files: {e}"
+
+
+ def get_repo_file_info(
+     repo_id: str,
+     token: Optional[str] = None,
+     revision: str = "main",
+ ) -> tuple[bool, list[dict], str]:
+     """
+     Get detailed file information from a Huggingface Dataset repository.
+
+     Args:
+         repo_id: Repository ID
+         token: Huggingface token (optional for public repos)
+         revision: Branch/revision name
+
+     Returns:
+         Tuple of (success, file_info_list, message)
+     """
+     try:
+         from huggingface_hub import HfApi
+
+         api = HfApi(token=token)
+
+         repo_info = api.repo_info(
+             repo_id=repo_id,
+             repo_type="dataset",
+             revision=revision,
+             files_metadata=True,
+         )
+
+         file_infos = []
+         if repo_info.siblings:
+             for sibling in repo_info.siblings:
+                 file_infos.append({
+                     "path": sibling.rfilename,
+                     "size": sibling.size,
+                     "blob_id": sibling.blob_id,
+                 })
+
+         return True, file_infos, f"Found {len(file_infos)} files in {repo_id}"
+
+     except Exception as e:
+         return False, [], f"Error getting repository info: {e}"
+
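The listing helpers share a (success, payload, message) return convention, so callers branch on the flag instead of catching exceptions. A sketch with a placeholder repo id:

    ok, infos, msg = get_repo_file_info("username/my-arena-data")
    if not ok:
        raise SystemExit(msg)
    for info in infos:
        print(info["path"], format_file_size(info["size"]))  # size may be None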
+
+ def upload_file(
+     repo_id: str,
+     local_path: str,
+     remote_path: str,
+     token: str,
+     commit_message: Optional[str] = None,
+     max_retries: int = DEFAULT_MAX_RETRIES,
+     repo_type: str = "dataset",
+ ) -> tuple[bool, str]:
+     """
+     Upload a single file to a Huggingface repository with retry support.
+
+     Args:
+         repo_id: Repository ID
+         local_path: Local file path
+         remote_path: Path in the repository
+         token: Huggingface token
+         commit_message: Optional commit message
+         max_retries: Maximum number of retry attempts on failure
+         repo_type: Repository type ("dataset", "model", or "space")
+
+     Returns:
+         Tuple of (success, message)
+     """
+     from huggingface_hub import HfApi
+
+     api = HfApi(token=token)
+
+     if not commit_message:
+         commit_message = f"Upload {remote_path}"
+
+     @retry_on_failure(
+         max_retries=max_retries,
+         delay=DEFAULT_RETRY_DELAY,
+         backoff=DEFAULT_RETRY_BACKOFF,
+     )
+     def _do_upload() -> None:
+         api.upload_file(
+             path_or_fileobj=local_path,
+             path_in_repo=remote_path,
+             repo_id=repo_id,
+             repo_type=repo_type,
+             commit_message=commit_message,
+         )
+
+     try:
+         _do_upload()
+         return True, f"Uploaded {remote_path}"
+     except Exception as e:
+         return False, f"Error uploading file: {e}"
+
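Editorial sketch of a single-file upload (placeholder paths and repo id; HF_TOKEN must be set in the environment):

    token = require_hf_token()
    ok, msg = upload_file(
        repo_id="username/my-arena-data",
        local_path="/tmp/leaderboard.json",
        remote_path="subset_a/leaderboard.json",
        token=token,
        max_retries=5,  # retried with exponential backoff via retry_on_failure
    )
    print(msg)  # "Uploaded subset_a/leaderboard.json" or "Error uploading file: ..."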
+
+ def upload_files_batch(
+     repo_id: str,
+     file_mappings: list[tuple[str, str]],
+     token: str,
+     commit_message: Optional[str] = None,
+ ) -> tuple[bool, str]:
+     """
+     Upload multiple files in a single commit.
+
+     Args:
+         repo_id: Repository ID
+         file_mappings: List of (local_path, remote_path) tuples
+         token: Huggingface token
+         commit_message: Optional commit message
+
+     Returns:
+         Tuple of (success, message)
+     """
+     try:
+         from huggingface_hub import HfApi, CommitOperationAdd
+
+         api = HfApi(token=token)
+
+         if not commit_message:
+             commit_message = f"Upload {len(file_mappings)} files"
+
+         operations = [
+             CommitOperationAdd(
+                 path_in_repo=remote_path,
+                 path_or_fileobj=local_path,
+             )
+             for local_path, remote_path in file_mappings
+         ]
+
+         api.create_commit(
+             repo_id=repo_id,
+             repo_type="dataset",
+             operations=operations,
+             commit_message=commit_message,
+         )
+
+         return True, f"Uploaded {len(file_mappings)} files"
+
+     except Exception as e:
+         return False, f"Error uploading files: {e}"
+
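Because all operations go through one create_commit call, a batch lands as a single commit in the repo history rather than one commit per file. A sketch with placeholder paths:

    mappings = [
        ("/tmp/meta.json", "subset_a/meta.json"),
        ("/tmp/model_x.zip", "subset_a/models/model_x.zip"),
    ]
    ok, msg = upload_files_batch("username/my-arena-data", mappings, require_hf_token())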
+
+ def download_file(
+     repo_id: str,
+     remote_path: str,
+     local_path: str,
+     token: Optional[str] = None,
+     revision: str = "main",
+ ) -> tuple[bool, str]:
+     """
+     Download a single file from a Huggingface Dataset repository.
+
+     Args:
+         repo_id: Repository ID
+         remote_path: Path in the repository
+         local_path: Local file path to save to
+         token: Huggingface token (optional for public repos)
+         revision: Branch/revision name
+
+     Returns:
+         Tuple of (success, message)
+     """
+     try:
+         from huggingface_hub import hf_hub_download
+
+         # Ensure local directory exists
+         os.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+         # Download to a temp location first, then move
+         downloaded_path = hf_hub_download(
+             repo_id=repo_id,
+             filename=remote_path,
+             repo_type="dataset",
+             revision=revision,
+             token=token,
+             local_dir=os.path.dirname(local_path),
+             local_dir_use_symlinks=False,
+         )
+
+         # If downloaded to a different path, copy to expected location
+         if downloaded_path != local_path:
+             import shutil
+             os.makedirs(os.path.dirname(local_path), exist_ok=True)
+             shutil.copy2(downloaded_path, local_path)
+
+         return True, f"Downloaded {remote_path}"
+
+     except Exception as e:
+         return False, f"Error downloading file: {e}"
+
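Editorial sketch (placeholder repo id and paths); the parent directory of local_path is created automatically:

    ok, msg = download_file(
        repo_id="username/my-arena-data",
        remote_path="subset_a/leaderboard.json",
        local_path="arena/subset_a/leaderboard.json",
    )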
+
+ def check_file_exists(
+     repo_id: str,
+     remote_path: str,
+     token: Optional[str] = None,
+     revision: str = "main",
+ ) -> bool:
+     """
+     Check if a file exists in the repository.
+
+     Args:
+         repo_id: Repository ID
+         remote_path: Path in the repository
+         token: Huggingface token (optional for public repos)
+         revision: Branch/revision name
+
+     Returns:
+         True if file exists
+     """
+     try:
+         from huggingface_hub import HfApi
+
+         api = HfApi(token=token)
+         files = api.list_repo_files(
+             repo_id=repo_id,
+             repo_type="dataset",
+             revision=revision,
+         )
+
+         return remote_path in files
+
+     except Exception:
+         return False
+
+
+ def format_file_size(size_bytes: Optional[int]) -> str:
+     """
+     Format file size in human-readable format.
+
+     Args:
+         size_bytes: Size in bytes
+
+     Returns:
+         Human-readable size string
+     """
+     if size_bytes is None:
+         return "Unknown"
+
+     for unit in ["B", "KB", "MB", "GB", "TB"]:
+         if abs(size_bytes) < 1024.0:
+             return f"{size_bytes:.1f} {unit}"
+         size_bytes /= 1024.0
+
+     return f"{size_bytes:.1f} PB"
+
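For reference, some expected outputs of format_file_size as defined above:

    format_file_size(None)         # "Unknown"
    format_file_size(532)          # "532.0 B"
    format_file_size(1536)         # "1.5 KB"
    format_file_size(3 * 1024**3)  # "3.0 GB"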
+
+ # =============================================================================
+ # High-level operations
+ # =============================================================================
+
+ def upload_arena_data(
+     arena_dir: str,
+     repo_id: str,
+     subsets: Optional[list[str]] = None,
+     models: Optional[list[str]] = None,
+     experiments: Optional[list[str]] = None,
+     overwrite: bool = False,
+     show_progress: bool = True,
+     max_retries: int = DEFAULT_MAX_RETRIES,
+ ) -> tuple[bool, str]:
+     """
+     Upload arena data to a Huggingface Dataset repository.
+
+     This function:
+     1. Validates the repository exists and is a Dataset type
+     2. Collects files to upload based on filters
+     3. Packs directories into ZIP files
+     4. Uploads files with progress indication and retry on failure
+
+     Supports resume upload: by default (overwrite=False), already uploaded files
+     are automatically skipped, enabling resumable uploads after connection failures.
+
+     Args:
+         arena_dir: Path to the arena directory
+         repo_id: Huggingface repository ID
+         subsets: List of subsets to upload (None = all)
+         models: List of models to upload (None = all)
+         experiments: List of experiments (exp_name) to upload (None = all)
+         overwrite: If True, overwrite existing files; if False, skip existing (resume mode)
+         show_progress: If True, show progress bar
+         max_retries: Maximum number of retry attempts per file on failure
+
+     Returns:
+         Tuple of (success, message)
+     """
+     from genarena.sync.packer import (
+         collect_upload_tasks,
+         pack_model_dir,
+         pack_exp_dir,
+         TempPackingContext,
+         TaskType,
+     )
+
+     # Get token
+     try:
+         token = require_hf_token()
+     except ValueError as e:
+         return False, str(e)
+
+     # Validate repository
+     valid, msg = validate_dataset_repo(repo_id, token)
+     if not valid:
+         return False, msg
+
+     logger.info(f"Uploading to repository: {repo_id}")
+
+     # Collect upload tasks
+     tasks = collect_upload_tasks(arena_dir, subsets, models, experiments)
+     if not tasks:
+         return False, "No files to upload. Check arena_dir and filters."
+
+     logger.info(f"Found {len(tasks)} items to scan")
+
+     # Get existing files in repo (for overwrite check)
+     existing_files = set()
+     if not overwrite:
+         logger.info("Checking existing files in remote repository...")
+         success, files, _ = list_repo_files(repo_id, token)
+         if success:
+             existing_files = set(files)
+             logger.info(f"Found {len(existing_files)} files in remote repository")
+
+     # Pre-scan: categorize tasks into to_upload and to_skip
+     to_upload = []
+     to_skip = []
+     for task in tasks:
+         if not overwrite and task.remote_path in existing_files:
+             to_skip.append(task)
+         else:
+             to_upload.append(task)
+
+     # Display scan summary
+     logger.info(f"Scan summary: {len(to_upload)} to upload, {len(to_skip)} already exist (will skip)")
+
+     if to_skip:
+         logger.info("Already uploaded (will skip):")
+         for task in to_skip[:10]:
+             logger.info(f"  ✓ {task.remote_path}")
+         if len(to_skip) > 10:
+             logger.info(f"  ... and {len(to_skip) - 10} more")
+
+     if to_upload:
+         logger.info("To be uploaded:")
+         for task in to_upload[:10]:
+             logger.info(f"  → {task.remote_path}")
+         if len(to_upload) > 10:
+             logger.info(f"  ... and {len(to_upload) - 10} more")
+
+     if not to_upload:
+         return True, f"All {len(to_skip)} files already exist in repository. Nothing to upload."
+
+     # Process tasks (only those that need uploading)
+     uploaded = 0
+     skipped = len(to_skip)  # Pre-count skipped
+     failed = 0
+     errors = []
+
+     # Setup progress bar
+     if show_progress:
+         try:
+             from tqdm import tqdm
+             to_upload = tqdm(to_upload, desc="Uploading", unit="file")
+         except ImportError:
+             pass
+
+     with TempPackingContext() as ctx:
+         for task in to_upload:
+             try:
+                 if task.task_type == TaskType.MODEL_ZIP:
+                     # Pack model directory
+                     zip_path = ctx.get_temp_zip_path(task.remote_path)
+                     success, msg = pack_model_dir(task.local_path, zip_path)
+                     if not success:
+                         errors.append(f"{task.name}: {msg}")
+                         failed += 1
+                         continue
+
+                     # Upload ZIP with retry
+                     success, msg = upload_file(
+                         repo_id, zip_path, task.remote_path, token,
+                         commit_message=f"[genarena] Upload model: {task.subset}/{task.name}",
+                         max_retries=max_retries,
+                     )
+
+                 elif task.task_type == TaskType.EXP_ZIP:
+                     # Pack experiment directory
+                     zip_path = ctx.get_temp_zip_path(task.remote_path)
+                     success, msg = pack_exp_dir(task.local_path, zip_path)
+                     if not success:
+                         errors.append(f"{task.name}: {msg}")
+                         failed += 1
+                         continue
+
+                     # Upload ZIP with retry
+                     success, msg = upload_file(
+                         repo_id, zip_path, task.remote_path, token,
+                         commit_message=f"[genarena] Upload experiment: {task.subset}/{task.name}",
+                         max_retries=max_retries,
+                     )
+
+                 elif task.task_type == TaskType.SMALL_FILE:
+                     # Upload small file directly with retry
+                     success, msg = upload_file(
+                         repo_id, task.local_path, task.remote_path, token,
+                         commit_message=f"[genarena] Upload {task.name}",
+                         max_retries=max_retries,
+                     )
+
+                 else:
+                     success = False
+                     msg = f"Unknown task type: {task.task_type}"
+
+                 if success:
+                     uploaded += 1
+                     logger.debug(f"Uploaded: {task.remote_path}")
+                 else:
+                     errors.append(f"{task.name}: {msg}")
+                     failed += 1
+
+             except Exception as e:
+                 errors.append(f"{task.name}: {e}")
+                 failed += 1
+
+     # Summary
+     summary = f"Uploaded: {uploaded}, Skipped: {skipped}, Failed: {failed}"
+     if errors:
+         summary += f"\nErrors:\n" + "\n".join(f"  - {e}" for e in errors[:5])
+         if len(errors) > 5:
+             summary += f"\n  ... and {len(errors) - 5} more errors"
+
+     repo_url = f"https://huggingface.co/datasets/{repo_id}"
+     summary += f"\n\nRepository URL: {repo_url}"
+
+     success = failed == 0 or uploaded > 0
+     return success, summary
+
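Editorial sketch of a resumable upload (arena_dir layout and repo id are placeholders; HF_TOKEN must be set):

    ok, summary = upload_arena_data(
        arena_dir="./arena",
        repo_id="username/my-arena-data",
        subsets=["subset_a"],
        overwrite=False,  # resume mode: files already in the repo are skipped
    )
    print(summary)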
+
+ def pull_arena_data(
+     arena_dir: str,
+     repo_id: str,
+     subsets: Optional[list[str]] = None,
+     models: Optional[list[str]] = None,
+     experiments: Optional[list[str]] = None,
+     revision: str = "main",
+     overwrite: bool = False,
+     show_progress: bool = True,
+ ) -> tuple[bool, str]:
+     """
+     Pull arena data from a Huggingface Dataset repository.
+
+     This function:
+     1. Validates the repository exists and is a Dataset type
+     2. Lists files in the repository
+     3. Filters based on subsets/models
+     4. Downloads and unpacks ZIP files
+
+     Args:
+         arena_dir: Path to the local arena directory
+         repo_id: Huggingface repository ID
+         subsets: List of subsets to download (None = all)
+         models: List of models to download (None = all)
+         experiments: List of experiments (exp_name) to download (None = all)
+         revision: Branch/revision to download from
+         overwrite: If True, overwrite existing files
+         show_progress: If True, show progress bar
+
+     Returns:
+         Tuple of (success, message)
+     """
+     import tempfile
+     import shutil
+     from genarena.sync.packer import (
+         collect_download_tasks,
+         unpack_zip,
+         TaskType,
+     )
+
+     # Get token (optional for public repos)
+     token = get_hf_token()
+
+     # Validate repository
+     valid, msg = validate_dataset_repo(repo_id, token)
+     if not valid:
+         return False, msg
+
+     logger.info(f"Pulling from repository: {repo_id} (revision: {revision})")
+
+     # List files in repository
+     success, repo_files, msg = list_repo_files(repo_id, token, revision)
+     if not success:
+         return False, msg
+
+     if not repo_files:
+         return False, "Repository is empty"
+
+     # Collect download tasks
+     tasks = collect_download_tasks(repo_files, arena_dir, subsets, models, experiments)
+     if not tasks:
+         return False, "No matching files to download. Check filters."
+
+     logger.info(f"Found {len(tasks)} items to download")
+
+     # Process tasks
+     downloaded = 0
+     skipped = 0
+     failed = 0
+     errors = []
+
+     # Setup progress bar
+     if show_progress:
+         try:
+             from tqdm import tqdm
+             tasks = tqdm(tasks, desc="Downloading", unit="file")
+         except ImportError:
+             pass
+
+     # Create temp directory for downloads
+     temp_dir = tempfile.mkdtemp(prefix="genarena_pull_")
+
+     try:
+         for task in tasks:
+             try:
+                 if task.task_type in (TaskType.MODEL_ZIP, TaskType.EXP_ZIP):
+                     # Download ZIP to temp location
+                     temp_zip = os.path.join(temp_dir, os.path.basename(task.remote_path))
+                     success, msg = download_file(
+                         repo_id, task.remote_path, temp_zip, token, revision
+                     )
+
+                     if not success:
+                         errors.append(f"{task.name}: {msg}")
+                         failed += 1
+                         continue
+
+                     # Unpack ZIP
+                     success, msg = unpack_zip(temp_zip, task.local_path, overwrite)
+                     if not success:
+                         errors.append(f"{task.name}: {msg}")
+                         failed += 1
+                         continue
+
+                     downloaded += 1
+                     logger.debug(f"Downloaded and unpacked: {task.remote_path}")
+
+                 elif task.task_type == TaskType.SMALL_FILE:
+                     # Check if file exists and skip if not overwriting
+                     if os.path.exists(task.local_path) and not overwrite:
+                         logger.debug(f"Skipping existing: {task.local_path}")
+                         skipped += 1
+                         continue
+
+                     # Download file directly
+                     success, msg = download_file(
+                         repo_id, task.remote_path, task.local_path, token, revision
+                     )
+
+                     if success:
+                         downloaded += 1
+                         logger.debug(f"Downloaded: {task.remote_path}")
+                     else:
+                         errors.append(f"{task.name}: {msg}")
+                         failed += 1
+
+             except Exception as e:
+                 errors.append(f"{task.name}: {e}")
+                 failed += 1
+
+     finally:
+         # Cleanup temp directory
+         shutil.rmtree(temp_dir, ignore_errors=True)
+
+     # Summary
+     summary = f"Downloaded: {downloaded}, Skipped: {skipped}, Failed: {failed}"
+     if errors:
+         summary += f"\nErrors:\n" + "\n".join(f"  - {e}" for e in errors[:5])
+         if len(errors) > 5:
+             summary += f"\n  ... and {len(errors) - 5} more errors"
+
+     success = failed == 0 or downloaded > 0
+     return success, summary
+
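Editorial sketch of a filtered pull (placeholder repo id and model name; a token is only needed for private repos):

    ok, summary = pull_arena_data(
        arena_dir="./arena",
        repo_id="username/my-arena-data",
        models=["model_x"],  # None would pull everything
    )
    print(summary)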
+
+ def list_repo_contents(
+     repo_id: str,
+     revision: str = "main",
+ ) -> tuple[bool, str]:
+     """
+     List contents of a Huggingface Dataset repository.
+
+     Displays files organized by subset with size information.
+
+     Args:
+         repo_id: Huggingface repository ID
+         revision: Branch/revision name
+
+     Returns:
+         Tuple of (success, formatted_output)
+     """
+     # Get token (optional for public repos)
+     token = get_hf_token()
+
+     # Validate repository
+     valid, msg = validate_dataset_repo(repo_id, token)
+     if not valid:
+         return False, msg
+
+     # Get file info
+     success, file_infos, msg = get_repo_file_info(repo_id, token, revision)
+     if not success:
+         return False, msg
+
+     if not file_infos:
+         return True, f"Repository '{repo_id}' is empty"
+
+     # Organize by subset
+     subsets: dict[str, list[dict]] = {}
+     other_files: list[dict] = []
+
+     for info in file_infos:
+         path = info["path"]
+         parts = path.split("/")
+
+         if len(parts) >= 2:
+             subset = parts[0]
+             if subset not in subsets:
+                 subsets[subset] = []
+             subsets[subset].append(info)
+         else:
+             other_files.append(info)
+
+     # Format output
+     lines = [
+         f"Repository: {repo_id}",
+         f"Revision: {revision}",
+         f"Total files: {len(file_infos)}",
+         "",
+     ]
+
+     total_size = sum(f.get("size", 0) or 0 for f in file_infos)
+     lines.append(f"Total size: {format_file_size(total_size)}")
+     lines.append("")
+
+     for subset in sorted(subsets.keys()):
+         files = subsets[subset]
+         subset_size = sum(f.get("size", 0) or 0 for f in files)
+
+         lines.append(f"=== {subset} ({len(files)} files, {format_file_size(subset_size)}) ===")
+
+         # Organize by type
+         models = []
+         experiments = []
+         other = []
+
+         for f in files:
+             path = f["path"]
+             if "/models/" in path:
+                 models.append(f)
+             elif "/pk_logs/" in path:
+                 experiments.append(f)
+             else:
+                 other.append(f)
+
+         if models:
+             lines.append("  Models:")
+             for f in sorted(models, key=lambda x: x["path"]):
+                 size = format_file_size(f.get("size"))
+                 name = os.path.basename(f["path"])
+                 lines.append(f"    - {name} ({size})")
+
+         if experiments:
+             lines.append("  Experiments:")
+             for f in sorted(experiments, key=lambda x: x["path"]):
+                 size = format_file_size(f.get("size"))
+                 name = os.path.basename(f["path"])
+                 lines.append(f"    - {name} ({size})")
+
+         if other:
+             lines.append("  Other:")
+             for f in sorted(other, key=lambda x: x["path"]):
+                 size = format_file_size(f.get("size"))
+                 name = f["path"].split("/", 1)[1] if "/" in f["path"] else f["path"]
+                 lines.append(f"    - {name} ({size})")
+
+         lines.append("")
+
+     if other_files:
+         lines.append("=== Other files ===")
+         for f in sorted(other_files, key=lambda x: x["path"]):
+             size = format_file_size(f.get("size"))
+             lines.append(f"  - {f['path']} ({size})")
+
+     return True, "\n".join(lines)
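Editorial sketch (placeholder repo id):

    ok, report = list_repo_contents("username/my-arena-data")
    print(report)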