ml-dash 0.0.11__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (47)
  1. ml_dash/__init__.py +59 -1
  2. ml_dash/auto_start.py +42 -0
  3. ml_dash/cli.py +67 -0
  4. ml_dash/cli_commands/__init__.py +1 -0
  5. ml_dash/cli_commands/download.py +797 -0
  6. ml_dash/cli_commands/list.py +343 -0
  7. ml_dash/cli_commands/upload.py +1298 -0
  8. ml_dash/client.py +955 -0
  9. ml_dash/config.py +114 -11
  10. ml_dash/experiment.py +1020 -0
  11. ml_dash/files.py +688 -0
  12. ml_dash/log.py +181 -0
  13. ml_dash/metric.py +292 -0
  14. ml_dash/params.py +188 -0
  15. ml_dash/storage.py +1115 -0
  16. ml_dash-0.5.9.dist-info/METADATA +244 -0
  17. ml_dash-0.5.9.dist-info/RECORD +20 -0
  18. ml_dash-0.5.9.dist-info/WHEEL +4 -0
  19. ml_dash-0.5.9.dist-info/entry_points.txt +3 -0
  20. ml_dash/app.py +0 -33
  21. ml_dash/file_events.py +0 -71
  22. ml_dash/file_handlers.py +0 -141
  23. ml_dash/file_utils.py +0 -5
  24. ml_dash/file_watcher.py +0 -30
  25. ml_dash/main.py +0 -60
  26. ml_dash/mime_types.py +0 -20
  27. ml_dash/schema/__init__.py +0 -110
  28. ml_dash/schema/archive.py +0 -165
  29. ml_dash/schema/directories.py +0 -59
  30. ml_dash/schema/experiments.py +0 -65
  31. ml_dash/schema/files/__init__.py +0 -204
  32. ml_dash/schema/files/file_helpers.py +0 -79
  33. ml_dash/schema/files/images.py +0 -27
  34. ml_dash/schema/files/metrics.py +0 -64
  35. ml_dash/schema/files/parameters.py +0 -50
  36. ml_dash/schema/files/series.py +0 -235
  37. ml_dash/schema/files/videos.py +0 -27
  38. ml_dash/schema/helpers.py +0 -66
  39. ml_dash/schema/projects.py +0 -65
  40. ml_dash/schema/schema_helpers.py +0 -19
  41. ml_dash/schema/users.py +0 -33
  42. ml_dash/sse.py +0 -18
  43. ml_dash-0.0.11.dist-info/METADATA +0 -67
  44. ml_dash-0.0.11.dist-info/RECORD +0 -30
  45. ml_dash-0.0.11.dist-info/WHEEL +0 -5
  46. ml_dash-0.0.11.dist-info/top_level.txt +0 -1
  47. ml_dash/{example.py → py.typed} +0 -0
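
The bulk of the new code is the download subcommand; the single hunk below is the new file ml_dash/cli_commands/download.py (entry 5 above, +797 lines). For orientation, here is a minimal driver sketch: the "download" subcommand name and its flags come from add_parser() at the end of the hunk, while the ml-dash program name, server URL, and username are hypothetical placeholders.

    # Minimal driver sketch: wires the new download subcommand into an argparse CLI.
    # The "download" name and its flags are defined by add_parser() in the hunk below;
    # the prog name, server URL, and username here are hypothetical placeholders.
    import argparse

    from ml_dash.cli_commands import download

    parser = argparse.ArgumentParser(prog="ml-dash")
    subparsers = parser.add_subparsers(dest="command")
    download.add_parser(subparsers)

    args = parser.parse_args([
        "download", "./.ml-dash",
        "--remote", "https://ml-dash.example.com",
        "--username", "alice",
        "--project", "my-project",
        "--dry-run",
    ])
    raise SystemExit(args.func(args))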
@@ -0,0 +1,797 @@
+ """
+ CLI command for downloading experiments from remote server to local storage.
+ """
+
+ import argparse
+ import json
+ import os
+ import tempfile
+ import threading
+ import time
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from dataclasses import dataclass, field, asdict
+ from pathlib import Path
+ from typing import Optional, Dict, Any, List
+
+ from rich.console import Console
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TransferSpeedColumn
+ from rich.table import Table
+ from rich.panel import Panel
+
+ from ..client import RemoteClient
+ from ..storage import LocalStorage
+ from ..config import Config
+
+ console = Console()
+
+
+ # ============================================================================
+ # Data Classes
+ # ============================================================================
+
+ @dataclass
+ class ExperimentInfo:
+     """Information about an experiment to download."""
+     project: str
+     experiment: str
+     experiment_id: str
+     has_logs: bool = False
+     has_params: bool = False
+     metric_names: List[str] = field(default_factory=list)
+     file_count: int = 0
+     estimated_size: int = 0
+     log_count: int = 0
+     status: str = "RUNNING"
+     folder: Optional[str] = None
+     description: Optional[str] = None
+     tags: List[str] = field(default_factory=list)
+
+
+ @dataclass
+ class DownloadState:
+     """State for resuming interrupted downloads."""
+     remote_url: str
+     local_path: str
+     namespace: str
+     completed_experiments: List[str] = field(default_factory=list)
+     failed_experiments: List[str] = field(default_factory=list)
+     in_progress_experiment: Optional[str] = None
+     in_progress_items: Dict[str, Any] = field(default_factory=dict)
+     timestamp: Optional[str] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary."""
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "DownloadState":
+         """Create from dictionary."""
+         return cls(**data)
+
+     def save(self, path: Path):
+         """Save state to JSON file."""
+         path.write_text(json.dumps(self.to_dict(), indent=2))
+
+     @classmethod
+     def load(cls, path: Path) -> Optional["DownloadState"]:
+         """Load state from JSON file."""
+         if not path.exists():
+             return None
+         try:
+             data = json.loads(path.read_text())
+             return cls.from_dict(data)
+         except Exception as e:
+             console.print(f"[yellow]Warning: Could not load state file: {e}[/yellow]")
+             return None
+
+
+ @dataclass
+ class DownloadResult:
+     """Result of downloading an experiment."""
+     experiment: str
+     success: bool = False
+     downloaded: Dict[str, int] = field(default_factory=dict)
+     failed: Dict[str, List[str]] = field(default_factory=dict)
+     errors: List[str] = field(default_factory=list)
+     bytes_downloaded: int = 0
+     skipped: bool = False
+
+
+ # ============================================================================
+ # Helper Functions
+ # ============================================================================
+
+ def _format_bytes(bytes_count: int) -> str:
+     """Format bytes as human-readable string."""
+     for unit in ["B", "KB", "MB", "GB"]:
+         if bytes_count < 1024:
+             return f"{bytes_count:.2f} {unit}"
+         bytes_count /= 1024
+     return f"{bytes_count:.2f} TB"
+
+
+ def _format_bytes_per_sec(bytes_per_sec: float) -> str:
+     """Format bytes per second as human-readable string."""
+     return f"{_format_bytes(bytes_per_sec)}/s"
+
+
+ def _experiment_from_graphql(graphql_data: Dict[str, Any]) -> ExperimentInfo:
+     """Convert GraphQL experiment data to ExperimentInfo."""
+     log_metadata = graphql_data.get('logMetadata') or {}
+
+     # Extract folder from metadata if it exists
+     metadata = graphql_data.get('metadata') or {}
+     folder = metadata.get('folder') if isinstance(metadata, dict) else None
+
+     return ExperimentInfo(
+         project=graphql_data['project']['slug'],
+         experiment=graphql_data['name'],
+         experiment_id=graphql_data['id'],
+         has_logs=log_metadata.get('totalLogs', 0) > 0,
+         has_params=graphql_data.get('parameters') is not None,
+         metric_names=[m['name'] for m in graphql_data.get('metrics', []) or []],
+         file_count=len(graphql_data.get('files', []) or []),
+         log_count=int(log_metadata.get('totalLogs', 0)),
+         status=graphql_data.get('status', 'RUNNING'),
+         folder=folder,
+         description=graphql_data.get('description'),
+         tags=graphql_data.get('tags', []) or [],
+     )
+
+
+ def discover_experiments(
+     remote_client: RemoteClient,
+     namespace: str,
+     project_filter: Optional[str] = None,
+     experiment_filter: Optional[str] = None,
+ ) -> List[ExperimentInfo]:
+     """
+     Discover experiments on remote server using GraphQL.
+
+     Args:
+         remote_client: Remote API client
+         namespace: Namespace slug
+         project_filter: Optional project slug filter
+         experiment_filter: Optional experiment name filter
+
+     Returns:
+         List of ExperimentInfo objects
+     """
+     # Specific experiment requested
+     if project_filter and experiment_filter:
+         exp_data = remote_client.get_experiment_graphql(namespace, project_filter, experiment_filter)
+         if exp_data:
+             return [_experiment_from_graphql(exp_data)]
+         return []
+
+     # Project filter - get all experiments in project
+     if project_filter:
+         experiments_data = remote_client.list_experiments_graphql(namespace, project_filter)
+         return [_experiment_from_graphql(exp) for exp in experiments_data]
+
+     # No filter - get all projects and their experiments
+     projects = remote_client.list_projects_graphql(namespace)
+     all_experiments = []
+     for project in projects:
+         experiments_data = remote_client.list_experiments_graphql(namespace, project['slug'])
+         all_experiments.extend([_experiment_from_graphql(exp) for exp in experiments_data])
+
+     return all_experiments
+
+
+ def _get_or_generate_api_key(args: argparse.Namespace, config: Config) -> str:
+     """Get API key from args, config, or generate from username."""
+     if args.api_key:
+         return args.api_key
+     if config.api_key:
+         return config.api_key
+     if args.username:
+         from ..cli_commands.upload import generate_api_key_from_username
+         return generate_api_key_from_username(args.username)
+     return ""
+
+
+ # ============================================================================
+ # Experiment Downloader
+ # ============================================================================
+
+ class ExperimentDownloader:
+     """Downloads a single experiment from remote server."""
+
+     def __init__(
+         self,
+         local_storage: LocalStorage,
+         remote_client: RemoteClient,
+         batch_size: int = 1000,
+         skip_logs: bool = False,
+         skip_metrics: bool = False,
+         skip_files: bool = False,
+         skip_params: bool = False,
+         verbose: bool = False,
+         max_concurrent_metrics: int = 5,
+         max_concurrent_files: int = 3,
+     ):
+         self.local = local_storage
+         self.remote = remote_client
+         self.batch_size = batch_size
+         self.skip_logs = skip_logs
+         self.skip_metrics = skip_metrics
+         self.skip_files = skip_files
+         self.skip_params = skip_params
+         self.verbose = verbose
+         self.max_concurrent_metrics = max_concurrent_metrics
+         self.max_concurrent_files = max_concurrent_files
+         self._lock = threading.Lock()
+         self._thread_local = threading.local()
+
+     def _get_remote_client(self) -> RemoteClient:
+         """Get thread-local remote client for safe concurrent access."""
+         if not hasattr(self._thread_local, 'client'):
+             self._thread_local.client = RemoteClient(
+                 base_url=self.remote.base_url,
+                 api_key=self.remote.api_key
+             )
+         return self._thread_local.client
+
+     def download_experiment(self, exp_info: ExperimentInfo) -> DownloadResult:
+         """Download a complete experiment."""
+         result = DownloadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")
+
+         try:
+             if self.verbose:
+                 console.print(f" [dim]Downloading {exp_info.project}/{exp_info.experiment}[/dim]")
+
+             # Step 1: Download metadata and create experiment
+             self._download_metadata(exp_info, result)
+
+             # Step 2: Download parameters
+             if not self.skip_params and exp_info.has_params:
+                 self._download_parameters(exp_info, result)
+
+             # Step 3: Download logs
+             if not self.skip_logs and exp_info.has_logs:
+                 self._download_logs(exp_info, result)
+
+             # Step 4: Download metrics (parallel)
+             if not self.skip_metrics and exp_info.metric_names:
+                 self._download_metrics(exp_info, result)
+
+             # Step 5: Download files (parallel)
+             if not self.skip_files and exp_info.file_count > 0:
+                 self._download_files(exp_info, result)
+
+             result.success = True
+
+         except Exception as e:
+             result.success = False
+             result.errors.append(str(e))
+             if self.verbose:
+                 console.print(f" [red]Error: {e}[/red]")
+
+         return result
+
+     def _download_metadata(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download and create experiment metadata."""
+         # Create experiment directory structure with folder path
+         self.local.create_experiment(
+             project=exp_info.project,
+             name=exp_info.experiment,
+             description=exp_info.description,
+             tags=exp_info.tags,
+             bindrs=[],
+             folder=exp_info.folder,
+             metadata=None,
+         )
+
+     def _download_parameters(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download parameters."""
+         try:
+             params_data = self.remote.get_parameters(exp_info.experiment_id)
+             if params_data:
+                 self.local.write_parameters(
+                     project=exp_info.project,
+                     experiment=exp_info.experiment,
+                     data=params_data
+                 )
+                 result.downloaded["parameters"] = 1
+                 result.bytes_downloaded += len(json.dumps(params_data))
+         except Exception as e:
+             result.failed.setdefault("parameters", []).append(str(e))
+
+     def _download_logs(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download logs with pagination."""
+         try:
+             offset = 0
+             total_downloaded = 0
+
+             while True:
+                 logs_data = self.remote.query_logs(
+                     experiment_id=exp_info.experiment_id,
+                     limit=self.batch_size,
+                     offset=offset,
+                     order_by="sequenceNumber",
+                     order="asc"
+                 )
+
+                 logs = logs_data.get("logs", [])
+                 if not logs:
+                     break
+
+                 # Write logs
+                 for log in logs:
+                     self.local.write_log(
+                         project=exp_info.project,
+                         experiment=exp_info.experiment,
+                         message=log['message'],
+                         level=log['level'],
+                         timestamp=log['timestamp'],
+                         metadata=log.get('metadata')
+                     )
+
+                 total_downloaded += len(logs)
+                 result.bytes_downloaded += sum(len(json.dumps(log)) for log in logs)
+
+                 if not logs_data.get("hasMore", False):
+                     break
+
+                 offset += len(logs)
+
+             result.downloaded["logs"] = total_downloaded
+
+         except Exception as e:
+             result.failed.setdefault("logs", []).append(str(e))
+
+     def _download_metrics(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download all metrics in parallel."""
+         with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
+             future_to_metric = {}
+
+             for metric_name in exp_info.metric_names:
+                 future = executor.submit(
+                     self._download_single_metric,
+                     exp_info.experiment_id,
+                     exp_info.project,
+                     exp_info.experiment,
+                     metric_name
+                 )
+                 future_to_metric[future] = metric_name
+
+             for future in as_completed(future_to_metric):
+                 metric_name = future_to_metric[future]
+                 metric_result = future.result()
+
+                 with self._lock:
+                     if metric_result['success']:
+                         result.downloaded["metrics"] = result.downloaded.get("metrics", 0) + 1
+                         result.bytes_downloaded += metric_result['bytes']
+                     else:
+                         result.failed.setdefault("metrics", []).append(
+                             f"{metric_name}: {metric_result['error']}"
+                         )
+
+     def _download_single_chunk(self, experiment_id: str, metric_name: str, chunk_number: int):
+         """Download a single chunk (for parallel downloading)."""
+         remote = self._get_remote_client()
+         try:
+             chunk_data = remote.download_metric_chunk(experiment_id, metric_name, chunk_number)
+             return {
+                 'success': True,
+                 'chunk_number': chunk_number,
+                 'data': chunk_data.get('data', []),
+                 'start_index': int(chunk_data.get('startIndex', 0)),
+                 'end_index': int(chunk_data.get('endIndex', 0)),
+                 'error': None
+             }
+         except Exception as e:
+             return {
+                 'success': False,
+                 'chunk_number': chunk_number,
+                 'error': str(e),
+                 'data': []
+             }
+
+     def _download_single_metric(self, experiment_id: str, project: str, experiment: str, metric_name: str):
+         """Download a single metric using chunk-aware approach (thread-safe)."""
+         remote = self._get_remote_client()
+
+         total_downloaded = 0
+         bytes_downloaded = 0
+
+         try:
+             # Get metric metadata to determine download strategy
+             metadata = remote.get_metric_stats(experiment_id, metric_name)
+             total_chunks = metadata.get('totalChunks', 0)
+             buffered_points = int(metadata.get('bufferedDataPoints', 0))
+
+             all_data = []
+
+             # Download chunks in parallel if they exist
+             if total_chunks > 0:
+                 from concurrent.futures import ThreadPoolExecutor, as_completed
+
+                 # Download all chunks in parallel (max 10 workers)
+                 with ThreadPoolExecutor(max_workers=min(10, total_chunks)) as executor:
+                     chunk_futures = {
+                         executor.submit(self._download_single_chunk, experiment_id, metric_name, i): i
+                         for i in range(total_chunks)
+                     }
+
+                     for future in as_completed(chunk_futures):
+                         result = future.result()
+                         if result['success']:
+                             all_data.extend(result['data'])
+                         else:
+                             # If a chunk fails, fall back to pagination
+                             raise Exception(f"Chunk {result['chunk_number']} download failed: {result['error']}")
+
+             # Download buffer data if exists
+             if buffered_points > 0:
+                 response = remote.get_metric_data(
+                     experiment_id, metric_name,
+                     buffer_only=True
+                 )
+                 buffer_data = response.get('data', [])
+                 all_data.extend(buffer_data)
+
+             # Sort all data by index
+             all_data.sort(key=lambda x: int(x.get('index', 0)))
+
+             # Write to local storage in batches
+             batch_size = 10000
+             for i in range(0, len(all_data), batch_size):
+                 batch = all_data[i:i + batch_size]
+                 self.local.append_batch_to_metric(
+                     project, experiment, metric_name,
+                     data_points=[d['data'] for d in batch]
+                 )
+                 total_downloaded += len(batch)
+                 bytes_downloaded += sum(len(json.dumps(d)) for d in batch)
+
+             return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
+
+         except Exception as e:
+             # Fall back to pagination if chunk download fails
+             console.print(f"[yellow]Chunk download failed for {metric_name}, falling back to pagination: {e}[/yellow]")
+             return self._download_metric_with_pagination(experiment_id, project, experiment, metric_name)
+
+     def _download_metric_with_pagination(self, experiment_id: str, project: str, experiment: str, metric_name: str):
+         """Original pagination-based download (fallback method)."""
+         remote = self._get_remote_client()
+
+         total_downloaded = 0
+         bytes_downloaded = 0
+         start_index = 0
+
+         try:
+             while True:
+                 response = remote.get_metric_data(
+                     experiment_id, metric_name,
+                     start_index=start_index,
+                     limit=self.batch_size
+                 )
+
+                 data_points = response.get('data', [])
+                 if not data_points:
+                     break
+
+                 # Write to local storage
+                 self.local.append_batch_to_metric(
+                     project, experiment, metric_name,
+                     data_points=[d['data'] for d in data_points]
+                 )
+
+                 total_downloaded += len(data_points)
+                 bytes_downloaded += sum(len(json.dumps(d)) for d in data_points)
+
+                 if not response.get('hasMore', False):
+                     break
+
+                 start_index += len(data_points)
+
+             return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
+
+         except Exception as e:
+             return {'success': False, 'error': str(e), 'downloaded': 0, 'bytes': 0}
+
+     def _download_files(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download files in parallel."""
+         # Get file list
+         try:
+             files_data = self.remote.list_files(exp_info.experiment_id)
+         except Exception as e:
+             result.failed.setdefault("files", []).append(f"List files failed: {e}")
+             return
+
+         if not files_data:
+             return
+
+         with ThreadPoolExecutor(max_workers=self.max_concurrent_files) as executor:
+             future_to_file = {}
+
+             for file_info in files_data:
+                 future = executor.submit(
+                     self._download_single_file,
+                     exp_info.experiment_id,
+                     exp_info.project,
+                     exp_info.experiment,
+                     file_info
+                 )
+                 future_to_file[future] = file_info['filename']
+
+             for future in as_completed(future_to_file):
+                 filename = future_to_file[future]
+                 file_result = future.result()
+
+                 with self._lock:
+                     if file_result['success']:
+                         result.downloaded["files"] = result.downloaded.get("files", 0) + 1
+                         result.bytes_downloaded += file_result['bytes']
+                     else:
+                         result.failed.setdefault("files", []).append(
+                             f"{filename}: {file_result['error']}"
+                         )
+
+     def _download_single_file(self, experiment_id: str, project: str, experiment: str, file_info: Dict[str, Any]):
+         """Download a single file with streaming (thread-safe)."""
+         remote = self._get_remote_client()
+
+         try:
+             # Stream download to temp file
+             temp_fd, temp_path = tempfile.mkstemp(prefix="ml_dash_download_")
+             os.close(temp_fd)
+
+             remote.download_file_streaming(
+                 experiment_id, file_info['id'], dest_path=temp_path
+             )
+
+             # Write to local storage
+             self.local.write_file(
+                 project=project,
+                 experiment=experiment,
+                 file_path=temp_path,
+                 prefix=file_info['path'],
+                 filename=file_info['filename'],
+                 description=file_info.get('description'),
+                 tags=file_info.get('tags', []),
+                 metadata=file_info.get('metadata'),
+                 checksum=file_info['checksum'],
+                 content_type=file_info['contentType'],
+                 size_bytes=file_info['sizeBytes']
+             )
+
+             # Clean up temp file
+             os.remove(temp_path)
+
+             return {'success': True, 'bytes': file_info['sizeBytes'], 'error': None}
+
+         except Exception as e:
+             return {'success': False, 'error': str(e), 'bytes': 0}
+
+
+ # ============================================================================
+ # Main Command
+ # ============================================================================
+
+ def cmd_download(args: argparse.Namespace) -> int:
+     """Execute download command."""
+     # Load configuration
+     config = Config()
+     remote_url = args.remote or config.remote_url
+     api_key = _get_or_generate_api_key(args, config)
+     namespace = args.namespace or args.username
+
+     # Validate inputs
+     if not remote_url:
+         console.print("[red]Error:[/red] --remote is required (or set in config)")
+         return 1
+
+     if not api_key:
+         console.print("[red]Error:[/red] --api-key or --username is required")
+         return 1
+
+     if not namespace:
+         console.print("[red]Error:[/red] --namespace or --username is required")
+         return 1
+
+     # Initialize clients
+     remote_client = RemoteClient(base_url=remote_url, api_key=api_key)
+     local_storage = LocalStorage(root_path=Path(args.path))
+
+     # Load or create state
+     state_file = Path(args.state_file)
+     if args.resume:
+         state = DownloadState.load(state_file)
+         if state:
+             console.print(f"[cyan]Resuming from previous download ({len(state.completed_experiments)} completed)[/cyan]")
+         else:
+             console.print("[yellow]No previous state found, starting fresh[/yellow]")
+             state = DownloadState(
+                 remote_url=remote_url,
+                 local_path=str(args.path),
+                 namespace=namespace
+             )
+     else:
+         state = DownloadState(
+             remote_url=remote_url,
+             local_path=str(args.path),
+             namespace=namespace
+         )
+
+     # Discover experiments
+     console.print("[bold]Discovering experiments on remote server...[/bold]")
+     try:
+         experiments = discover_experiments(
+             remote_client, namespace, args.project, args.experiment
+         )
+     except Exception as e:
+         console.print(f"[red]Failed to discover experiments: {e}[/red]")
+         return 1
+
+     if not experiments:
+         console.print("[yellow]No experiments found[/yellow]")
+         return 0
+
+     console.print(f"Found {len(experiments)} experiment(s)")
+
+     # Filter out completed experiments
+     experiments_to_download = []
+     for exp in experiments:
+         exp_key = f"{exp.project}/{exp.experiment}"
+
+         # Skip if already completed
+         if exp_key in state.completed_experiments and not args.overwrite:
+             console.print(f" [dim]Skipping {exp_key} (already completed)[/dim]")
+             continue
+
+         # Check if exists locally
+         exp_json = local_storage.root_path / exp.project / exp.experiment / "experiment.json"
+         if exp_json.exists() and not args.overwrite:
+             console.print(f" [yellow]Skipping {exp_key} (already exists locally)[/yellow]")
+             continue
+
+         experiments_to_download.append(exp)
+
+     if not experiments_to_download:
+         console.print("[green]All experiments already downloaded[/green]")
+         return 0
+
+     # Dry run mode
+     if args.dry_run:
+         console.print("\n[bold]Dry run - would download:[/bold]")
+         for exp in experiments_to_download:
+             console.print(f" • {exp.project}/{exp.experiment}")
+             console.print(f" Logs: {exp.log_count}, Metrics: {len(exp.metric_names)}, Files: {exp.file_count}")
+         return 0
+
+     # Download experiments
+     console.print(f"\n[bold]Downloading {len(experiments_to_download)} experiment(s)...[/bold]")
+     results = []
+     start_time = time.time()
+
+     for i, exp in enumerate(experiments_to_download, 1):
+         exp_key = f"{exp.project}/{exp.experiment}"
+         console.print(f"\n[cyan][{i}/{len(experiments_to_download)}] {exp_key}[/cyan]")
+
+         # Mark as in-progress
+         state.in_progress_experiment = exp_key
+         state.save(state_file)
+
+         # Download
+         downloader = ExperimentDownloader(
+             local_storage=local_storage,
+             remote_client=remote_client,
+             batch_size=args.batch_size,
+             skip_logs=args.skip_logs,
+             skip_metrics=args.skip_metrics,
+             skip_files=args.skip_files,
+             skip_params=args.skip_params,
+             verbose=args.verbose,
+             max_concurrent_metrics=args.max_concurrent_metrics,
+             max_concurrent_files=args.max_concurrent_files,
+         )
+
+         result = downloader.download_experiment(exp)
+         results.append(result)
+
+         # Update state
+         if result.success:
+             state.completed_experiments.append(exp_key)
+             console.print(f" [green]✓ Downloaded successfully[/green]")
+         else:
+             state.failed_experiments.append(exp_key)
+             console.print(f" [red]✗ Failed: {', '.join(result.errors)}[/red]")
+
+         state.in_progress_experiment = None
+         state.save(state_file)
+
+     # Show summary
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     total_bytes = sum(r.bytes_downloaded for r in results)
+     successful = sum(1 for r in results if r.success)
+
+     console.print("\n[bold]Download Summary[/bold]")
+     summary_table = Table()
+     summary_table.add_column("Metric", style="cyan")
+     summary_table.add_column("Value", style="green")
+
+     summary_table.add_row("Total Experiments", str(len(results)))
+     summary_table.add_row("Successful", str(successful))
+     summary_table.add_row("Failed", str(len(results) - successful))
+     summary_table.add_row("Total Data", _format_bytes(total_bytes))
+     summary_table.add_row("Total Time", f"{elapsed_time:.2f}s")
+
+     if elapsed_time > 0:
+         speed = total_bytes / elapsed_time
+         summary_table.add_row("Avg Speed", _format_bytes_per_sec(speed))
+
+     console.print(summary_table)
+
+     # Clean up state file if all successful
+     if all(r.success for r in results):
+         state_file.unlink(missing_ok=True)
+
+     return 0 if all(r.success for r in results) else 1
+
+
+ def add_parser(subparsers):
+     """Add download command parser."""
+     parser = subparsers.add_parser(
+         "download",
+         help="Download experiments from remote server to local storage"
+     )
+
+     # Positional arguments
+     parser.add_argument(
+         "path",
+         nargs="?",
+         default="./.ml-dash",
+         help="Local storage directory (default: ./.ml-dash)"
+     )
+
+     # Remote configuration
+     parser.add_argument("--remote", help="Remote server URL")
+     parser.add_argument("--api-key", help="JWT authentication token")
+     parser.add_argument("--username", help="Username for auto-generating API key")
+
+     # Scope control
+     parser.add_argument("--project", help="Download only this project")
+     parser.add_argument("--experiment", help="Download specific experiment (requires --project)")
+     parser.add_argument("--namespace", help="Namespace slug (defaults to username)")
+
+     # Data filtering
+     parser.add_argument("--skip-logs", action="store_true", help="Don't download logs")
+     parser.add_argument("--skip-metrics", action="store_true", help="Don't download metrics")
+     parser.add_argument("--skip-files", action="store_true", help="Don't download files")
+     parser.add_argument("--skip-params", action="store_true", help="Don't download parameters")
+
+     # Behavior control
+     parser.add_argument("--dry-run", action="store_true", help="Preview without downloading")
+     parser.add_argument("--overwrite", action="store_true", help="Overwrite existing experiments")
+     parser.add_argument("--resume", action="store_true", help="Resume interrupted download")
+     parser.add_argument(
+         "--state-file",
+         default=".ml-dash-download-state.json",
+         help="State file path for resume (default: .ml-dash-download-state.json)"
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=1000,
+         help="Batch size for logs/metrics (default: 1000, max: 10000)"
+     )
+     parser.add_argument(
+         "--max-concurrent-metrics",
+         type=int,
+         default=5,
+         help="Parallel metric downloads (default: 5)"
+     )
+     parser.add_argument(
+         "--max-concurrent-files",
+         type=int,
+         default=3,
+         help="Parallel file downloads (default: 3)"
+     )
+     parser.add_argument("-v", "--verbose", action="store_true", help="Detailed progress output")
+
+     parser.set_defaults(func=cmd_download)
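
The command can also be driven programmatically. Below is a minimal sketch using only names defined in this file and its sibling modules; the server URL, API key, namespace, and project slug are hypothetical placeholders.

    # Programmatic sketch of the downloader defined above; base_url, api_key,
    # namespace, and project_filter values are hypothetical placeholders.
    from pathlib import Path

    from ml_dash.client import RemoteClient
    from ml_dash.storage import LocalStorage
    from ml_dash.cli_commands.download import ExperimentDownloader, discover_experiments

    remote = RemoteClient(base_url="https://ml-dash.example.com", api_key="...")
    local = LocalStorage(root_path=Path("./.ml-dash"))

    downloader = ExperimentDownloader(local_storage=local, remote_client=remote, verbose=True)
    for exp in discover_experiments(remote, namespace="alice", project_filter="my-project"):
        result = downloader.download_experiment(exp)
        print(result.experiment, "ok" if result.success else "; ".join(result.errors))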