ml-dash 0.6.2rc1__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,18 +9,16 @@ import tempfile
  import threading
  import time
  from concurrent.futures import ThreadPoolExecutor, as_completed
- from dataclasses import dataclass, field, asdict
+ from dataclasses import asdict, dataclass, field
  from pathlib import Path
- from typing import Optional, Dict, Any, List
+ from typing import Any, Dict, List, Optional

  from rich.console import Console
- from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TransferSpeedColumn
  from rich.table import Table
- from rich.panel import Panel

  from ..client import RemoteClient
- from ..storage import LocalStorage
  from ..config import Config
+ from ..storage import LocalStorage

  console = Console()

@@ -29,741 +27,850 @@ console = Console()
  # Data Classes
  # ============================================================================

+
  @dataclass
  class ExperimentInfo:
-     """Information about an experiment to download."""
-     project: str
-     experiment: str
-     experiment_id: str
-     has_logs: bool = False
-     has_params: bool = False
-     metric_names: List[str] = field(default_factory=list)
-     file_count: int = 0
-     estimated_size: int = 0
-     log_count: int = 0
-     status: str = "RUNNING"
-     folder: Optional[str] = None
-     description: Optional[str] = None
-     tags: List[str] = field(default_factory=list)
+     """Information about an experiment to download."""
+
+     project: str
+     experiment: str
+     experiment_id: str
+     owner: Optional[str] = None
+     has_logs: bool = False
+     has_params: bool = False
+     metric_names: List[str] = field(default_factory=list)
+     file_count: int = 0
+     estimated_size: int = 0
+     log_count: int = 0
+     status: str = "RUNNING"
+     prefix: Optional[str] = None
+     description: Optional[str] = None
+     tags: List[str] = field(default_factory=list)


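The dataclass gains two optional fields, `owner` and `prefix`, which carry the namespace information the rest of this release threads through. A minimal sketch of a populated record, with made-up values, assuming `ExperimentInfo` as defined above:

```python
# Illustrative only -- the values are invented, not from a real server.
info = ExperimentInfo(
    project="my-project",
    experiment="run-42",
    experiment_id="exp_123",
    owner="alice",                     # new in 0.6.4: owning namespace
    prefix="alice/my-project/run-42",  # new in 0.6.4: full on-disk path
    metric_names=["loss", "accuracy"],
    log_count=250,
    has_logs=True,
)
print(info.prefix)  # alice/my-project/run-42
```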
  @dataclass
  class DownloadState:
-     """State for resuming interrupted downloads."""
-     remote_url: str
-     local_path: str
-     completed_experiments: List[str] = field(default_factory=list)
-     failed_experiments: List[str] = field(default_factory=list)
-     in_progress_experiment: Optional[str] = None
-     in_progress_items: Dict[str, Any] = field(default_factory=dict)
-     timestamp: Optional[str] = None
-
-     def to_dict(self) -> Dict[str, Any]:
-         """Convert to dictionary."""
-         return asdict(self)
-
-     @classmethod
-     def from_dict(cls, data: Dict[str, Any]) -> "DownloadState":
-         """Create from dictionary."""
-         return cls(**data)
-
-     def save(self, path: Path):
-         """Save state to JSON file."""
-         path.write_text(json.dumps(self.to_dict(), indent=2))
-
-     @classmethod
-     def load(cls, path: Path) -> Optional["DownloadState"]:
-         """Load state from JSON file."""
-         if not path.exists():
-             return None
-         try:
-             data = json.loads(path.read_text())
-             return cls.from_dict(data)
-         except Exception as e:
-             console.print(f"[yellow]Warning: Could not load state file: {e}[/yellow]")
-             return None
+     """State for resuming interrupted downloads."""
+
+     remote_url: str
+     local_path: str
+     completed_experiments: List[str] = field(default_factory=list)
+     failed_experiments: List[str] = field(default_factory=list)
+     in_progress_experiment: Optional[str] = None
+     in_progress_items: Dict[str, Any] = field(default_factory=dict)
+     timestamp: Optional[str] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary."""
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "DownloadState":
+         """Create from dictionary."""
+         return cls(**data)
+
+     def save(self, path: Path):
+         """Save state to JSON file."""
+         path.write_text(json.dumps(self.to_dict(), indent=2))
+
+     @classmethod
+     def load(cls, path: Path) -> Optional["DownloadState"]:
+         """Load state from JSON file."""
+         if not path.exists():
+             return None
+         try:
+             data = json.loads(path.read_text())
+             return cls.from_dict(data)
+         except Exception as e:
+             console.print(f"[yellow]Warning: Could not load state file: {e}[/yellow]")
+             return None


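`DownloadState` changes only in formatting; it is still a plain JSON round trip. A quick sketch of the save/load cycle that `--resume` relies on (paths are illustrative):

```python
from pathlib import Path

state = DownloadState(remote_url="https://api.dash.ml", local_path="./.dash")
state.completed_experiments.append("my-project/run-42")
state.save(Path(".dash-download-state.json"))

restored = DownloadState.load(Path(".dash-download-state.json"))
assert restored is not None
assert restored.completed_experiments == ["my-project/run-42"]
```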
  @dataclass
  class DownloadResult:
-     """Result of downloading an experiment."""
-     experiment: str
-     success: bool = False
-     downloaded: Dict[str, int] = field(default_factory=dict)
-     failed: Dict[str, List[str]] = field(default_factory=dict)
-     errors: List[str] = field(default_factory=list)
-     bytes_downloaded: int = 0
-     skipped: bool = False
+     """Result of downloading an experiment."""
+
+     experiment: str
+     success: bool = False
+     downloaded: Dict[str, int] = field(default_factory=dict)
+     failed: Dict[str, List[str]] = field(default_factory=dict)
+     errors: List[str] = field(default_factory=list)
+     bytes_downloaded: int = 0
+     skipped: bool = False


  # ============================================================================
  # Helper Functions
  # ============================================================================

+
  def _format_bytes(bytes_count: int) -> str:
-     """Format bytes as human-readable string."""
-     for unit in ["B", "KB", "MB", "GB"]:
-         if bytes_count < 1024:
-             return f"{bytes_count:.2f} {unit}"
-         bytes_count /= 1024
-     return f"{bytes_count:.2f} TB"
+     """Format bytes as human-readable string."""
+     for unit in ["B", "KB", "MB", "GB"]:
+         if bytes_count < 1024:
+             return f"{bytes_count:.2f} {unit}"
+         bytes_count /= 1024
+     return f"{bytes_count:.2f} TB"


  def _format_bytes_per_sec(bytes_per_sec: float) -> str:
-     """Format bytes per second as human-readable string."""
-     return f"{_format_bytes(bytes_per_sec)}/s"
+     """Format bytes per second as human-readable string."""
+     return f"{_format_bytes(bytes_per_sec)}/s"


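The byte formatters are only re-indented, not changed. For reference, the outputs the loop above produces:

```python
print(_format_bytes(512))                  # 512.00 B
print(_format_bytes(1536))                 # 1.50 KB
print(_format_bytes(3 * 1024**3))          # 3.00 GB
print(_format_bytes_per_sec(2 * 1024**2))  # 2.00 MB/s
```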
  def _experiment_from_graphql(graphql_data: Dict[str, Any]) -> ExperimentInfo:
-     """Convert GraphQL experiment data to ExperimentInfo."""
-     log_metadata = graphql_data.get('logMetadata') or {}
-
-     # Extract folder from metadata if it exists
-     metadata = graphql_data.get('metadata') or {}
-     folder = metadata.get('folder') if isinstance(metadata, dict) else None
-
-     return ExperimentInfo(
-         project=graphql_data['project']['slug'],
-         experiment=graphql_data['name'],
-         experiment_id=graphql_data['id'],
-         has_logs=log_metadata.get('totalLogs', 0) > 0,
-         has_params=graphql_data.get('parameters') is not None,
-         metric_names=[m['name'] for m in graphql_data.get('metrics', []) or []],
-         file_count=len(graphql_data.get('files', []) or []),
-         log_count=int(log_metadata.get('totalLogs', 0)),
-         status=graphql_data.get('status', 'RUNNING'),
-         folder=folder,
-         description=graphql_data.get('description'),
-         tags=graphql_data.get('tags', []) or [],
-     )
+     """Convert GraphQL experiment data to ExperimentInfo."""
+     log_metadata = graphql_data.get("logMetadata") or {}
+
+     # Extract prefix from metadata if it exists
+     metadata = graphql_data.get("metadata") or {}
+     prefix = metadata.get("prefix") if isinstance(metadata, dict) else None
+
+     # Extract owner/namespace from project data
+     project_data = graphql_data.get("project", {})
+     namespace_data = project_data.get("namespace", {})
+     owner = namespace_data.get("slug") if namespace_data else None
+
+     return ExperimentInfo(
+         project=project_data.get("slug", "unknown"),
+         experiment=graphql_data["name"],
+         experiment_id=graphql_data["id"],
+         owner=owner,
+         has_logs=log_metadata.get("totalLogs", 0) > 0,
+         has_params=graphql_data.get("parameters") is not None,
+         metric_names=[m["name"] for m in graphql_data.get("metrics", []) or []],
+         file_count=len(graphql_data.get("files", []) or []),
+         log_count=int(log_metadata.get("totalLogs", 0)),
+         status=graphql_data.get("status", "RUNNING"),
+         prefix=prefix,
+         description=graphql_data.get("description"),
+         tags=graphql_data.get("tags", []) or [],
+     )


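The rewritten converter reads `project.namespace.slug` and `metadata.prefix`, neither of which the 0.6.2rc1 version touched. The exact GraphQL schema is not shown in this diff, so the payload below is an assumption inferred from the accessors; it sketches a shape that would exercise the new fields:

```python
graphql_data = {
    "id": "exp_123",
    "name": "run-42",
    "status": "COMPLETED",
    "project": {"slug": "my-project", "namespace": {"slug": "alice"}},
    "logMetadata": {"totalLogs": 250},
    "parameters": {"lr": 3e-4},
    "metrics": [{"name": "loss"}],
    "files": [],
    "metadata": {"prefix": "alice/my-project/run-42"},
    "tags": ["baseline"],
}

info = _experiment_from_graphql(graphql_data)
assert info.owner == "alice"
assert info.prefix == "alice/my-project/run-42"
```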
  def discover_experiments(
-     remote_client: RemoteClient,
-     project_filter: Optional[str] = None,
-     experiment_filter: Optional[str] = None,
+     remote_client: RemoteClient,
+     project_filter: Optional[str] = None,
+     experiment_filter: Optional[str] = None,
  ) -> List[ExperimentInfo]:
-     """
-     Discover experiments on remote server using GraphQL.
-
-     Args:
-         remote_client: Remote API client
-         project_filter: Optional project slug filter
-         experiment_filter: Optional experiment name filter
-
-     Returns:
-         List of ExperimentInfo objects
-     """
-     # Specific experiment requested
-     if project_filter and experiment_filter:
-         exp_data = remote_client.get_experiment_graphql(project_filter, experiment_filter)
-         if exp_data:
-             return [_experiment_from_graphql(exp_data)]
-         return []
-
-     # Project filter - get all experiments in project
-     if project_filter:
-         experiments_data = remote_client.list_experiments_graphql(project_filter)
-         return [_experiment_from_graphql(exp) for exp in experiments_data]
-
-     # No filter - get all projects and their experiments
-     projects = remote_client.list_projects_graphql()
-     all_experiments = []
-     for project in projects:
-         experiments_data = remote_client.list_experiments_graphql(project['slug'])
-         all_experiments.extend([_experiment_from_graphql(exp) for exp in experiments_data])
-
-     return all_experiments
+     """
+     Discover experiments on remote server using GraphQL.
+
+     Args:
+         remote_client: Remote API client
+         project_filter: Optional project slug or glob pattern
+         experiment_filter: Optional experiment name filter
+
+     Returns:
+         List of ExperimentInfo objects
+     """
+     # Specific experiment requested
+     if project_filter and experiment_filter:
+         exp_data = remote_client.get_experiment_graphql(project_filter, experiment_filter)
+         if exp_data:
+             return [_experiment_from_graphql(exp_data)]
+         return []
+
+     # Check if project_filter contains glob wildcards
+     has_wildcards = project_filter and any(c in project_filter for c in ['*', '?', '['])
+
+     if has_wildcards:
+         # Use searchExperiments for glob patterns
+         # Expand simple project patterns to full namespace/project/experiment format
+         if '/' not in project_filter:
+             # Simple project pattern: "tut*" -> "*/tut*/*"
+             search_pattern = f"*/{project_filter}/*"
+         else:
+             # Full or partial pattern: use as-is
+             search_pattern = project_filter
+
+         experiments_data = remote_client.search_experiments_graphql(search_pattern)
+         return [_experiment_from_graphql(exp) for exp in experiments_data]
+
+     # Project filter - get all experiments in project (exact match)
+     if project_filter:
+         experiments_data = remote_client.list_experiments_graphql(project_filter)
+         return [_experiment_from_graphql(exp) for exp in experiments_data]
+
+     # No filter - get all projects and their experiments
+     projects = remote_client.list_projects_graphql()
+     all_experiments = []
+     for project in projects:
+         experiments_data = remote_client.list_experiments_graphql(project["slug"])
+         all_experiments.extend([_experiment_from_graphql(exp) for exp in experiments_data])
+
+     return all_experiments


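The wildcard branch is new in 0.6.4: a bare pattern like `tut*` is widened to `*/tut*/*` so it matches any namespace and any experiment, while a pattern that already contains `/` is passed through unchanged. The routing, condensed into a standalone mirror of the branch logic above:

```python
def route_project_filter(project_filter):
    """Mirror of the branch logic in discover_experiments (illustrative)."""
    if any(c in project_filter for c in ["*", "?", "["]):
        if "/" not in project_filter:
            return "search", f"*/{project_filter}/*"  # "tut*" -> "*/tut*/*"
        return "search", project_filter               # "tom*/tutorials/*" as-is
    return "list", project_filter                     # exact project slug

assert route_project_filter("tut*") == ("search", "*/tut*/*")
assert route_project_filter("tom*/tutorials/*") == ("search", "tom*/tutorials/*")
assert route_project_filter("alice/my-project") == ("list", "alice/my-project")
```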
  # ============================================================================
  # Experiment Downloader
  # ============================================================================

- class ExperimentDownloader:
-     """Downloads a single experiment from remote server."""
-
-     def __init__(
-         self,
-         local_storage: LocalStorage,
-         remote_client: RemoteClient,
-         batch_size: int = 1000,
-         skip_logs: bool = False,
-         skip_metrics: bool = False,
-         skip_files: bool = False,
-         skip_params: bool = False,
-         verbose: bool = False,
-         max_concurrent_metrics: int = 5,
-         max_concurrent_files: int = 3,
-     ):
-         self.local = local_storage
-         self.remote = remote_client
-         self.batch_size = batch_size
-         self.skip_logs = skip_logs
-         self.skip_metrics = skip_metrics
-         self.skip_files = skip_files
-         self.skip_params = skip_params
-         self.verbose = verbose
-         self.max_concurrent_metrics = max_concurrent_metrics
-         self.max_concurrent_files = max_concurrent_files
-         self._lock = threading.Lock()
-         self._thread_local = threading.local()
-
-     def _get_remote_client(self) -> RemoteClient:
-         """Get thread-local remote client for safe concurrent access."""
-         if not hasattr(self._thread_local, 'client'):
-             self._thread_local.client = RemoteClient(
-                 base_url=self.remote.base_url,
-                 api_key=self.remote.api_key
-             )
-         return self._thread_local.client

-     def download_experiment(self, exp_info: ExperimentInfo) -> DownloadResult:
-         """Download a complete experiment."""
-         result = DownloadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")
+ class ExperimentDownloader:
+     """Downloads a single experiment from remote server."""

-         try:
-             if self.verbose:
-                 console.print(f" [dim]Downloading {exp_info.project}/{exp_info.experiment}[/dim]")
+     def __init__(
+         self,
+         local_storage: LocalStorage,
+         remote_client: RemoteClient,
+         batch_size: int = 1000,
+         skip_logs: bool = False,
+         skip_metrics: bool = False,
+         skip_files: bool = False,
+         skip_params: bool = False,
+         verbose: bool = False,
+         max_concurrent_metrics: int = 5,
+         max_concurrent_files: int = 3,
+     ):
+         self.local = local_storage
+         self.remote = remote_client
+         self.batch_size = batch_size
+         self.skip_logs = skip_logs
+         self.skip_metrics = skip_metrics
+         self.skip_files = skip_files
+         self.skip_params = skip_params
+         self.verbose = verbose
+         self.max_concurrent_metrics = max_concurrent_metrics
+         self.max_concurrent_files = max_concurrent_files
+         self._lock = threading.Lock()
+         self._thread_local = threading.local()
+
+     def _get_remote_client(self) -> RemoteClient:
+         """Get thread-local remote client for safe concurrent access."""
+         if not hasattr(self._thread_local, "client"):
+             self._thread_local.client = RemoteClient(
+                 base_url=self.remote.base_url,
+                 namespace=self.remote.namespace,
+                 api_key=self.remote.api_key
+             )
+         return self._thread_local.client
+
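`_get_remote_client` now forwards the `namespace` as well; the thread-local pattern itself is unchanged. The idea, reduced to a standalone sketch with a stand-in class in place of `RemoteClient`:

```python
import threading

class _Client:
    """Stand-in for RemoteClient; real clients hold per-connection state."""
    def __init__(self, base_url, namespace=None, api_key=None):
        self.base_url, self.namespace, self.api_key = base_url, namespace, api_key

_thread_local = threading.local()

def get_client():
    # Each worker thread lazily builds its own client, so no connection
    # object is ever shared across threads.
    if not hasattr(_thread_local, "client"):
        _thread_local.client = _Client("https://api.dash.ml", namespace="alice")
    return _thread_local.client

assert get_client() is get_client()  # same thread -> same instance
```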
+     def download_experiment(self, exp_info: ExperimentInfo) -> DownloadResult:
+         """Download a complete experiment."""
+         result = DownloadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")

-             # Step 1: Download metadata and create experiment
-             self._download_metadata(exp_info, result)
+         try:
+             if self.verbose:
+                 console.print(
+                     f" [dim]Downloading {exp_info.project}/{exp_info.experiment}[/dim]"
+                 )

-             # Step 2: Download parameters
-             if not self.skip_params and exp_info.has_params:
-                 self._download_parameters(exp_info, result)
+             # Step 1: Download metadata and create experiment
+             self._download_metadata(exp_info, result)

-             # Step 3: Download logs
-             if not self.skip_logs and exp_info.has_logs:
-                 self._download_logs(exp_info, result)
+             # Step 2: Download parameters
+             if not self.skip_params and exp_info.has_params:
+                 self._download_parameters(exp_info, result)

-             # Step 4: Download metrics (parallel)
-             if not self.skip_metrics and exp_info.metric_names:
-                 self._download_metrics(exp_info, result)
+             # Step 3: Download logs
+             if not self.skip_logs and exp_info.has_logs:
+                 self._download_logs(exp_info, result)

-             # Step 5: Download files (parallel)
-             if not self.skip_files and exp_info.file_count > 0:
-                 self._download_files(exp_info, result)
+             # Step 4: Download metrics (parallel)
+             if not self.skip_metrics and exp_info.metric_names:
+                 self._download_metrics(exp_info, result)

-             result.success = True
+             # Step 5: Download files (parallel)
+             if not self.skip_files and exp_info.file_count > 0:
+                 self._download_files(exp_info, result)

-         except Exception as e:
-             result.success = False
-             result.errors.append(str(e))
-             if self.verbose:
-                 console.print(f" [red]Error: {e}[/red]")
+             result.success = True

-         return result
+         except Exception as e:
+             result.success = False
+             result.errors.append(str(e))
+             if self.verbose:
+                 console.print(f" [red]Error: {e}[/red]")
+
+         return result
+
+     def _download_metadata(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download and create experiment metadata."""
+         # Create experiment directory structure with prefix
+         # The prefix from server already includes the full path (owner/project/.../experiment_name)
+         if exp_info.prefix:
+             full_prefix = exp_info.prefix
+         else:
+             # Build prefix from owner/project/experiment
+             if exp_info.owner:
+                 full_prefix = f"{exp_info.owner}/{exp_info.project}/{exp_info.experiment}"
+             else:
+                 full_prefix = f"{exp_info.project}/{exp_info.experiment}"
+
+         # Extract owner from prefix or use from exp_info
+         owner = exp_info.owner if exp_info.owner else full_prefix.split('/')[0]
+
+         self.local.create_experiment(
+             owner=owner,
+             project=exp_info.project,
+             prefix=full_prefix,
+             description=exp_info.description,
+             tags=exp_info.tags,
+             bindrs=[],
+             metadata=None,
+         )

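The prefix fallback chain in the new `_download_metadata` is easy to trace by hand; a standalone restatement of the same precedence, with illustrative values:

```python
def derive_prefix(prefix, owner, project, experiment):
    """Server-supplied prefix wins; otherwise build owner/project/experiment."""
    if prefix:
        return prefix
    if owner:
        return f"{owner}/{project}/{experiment}"
    return f"{project}/{experiment}"

assert derive_prefix("alice/my-project/run-42", None, "p", "e") == "alice/my-project/run-42"
assert derive_prefix(None, "alice", "my-project", "run-42") == "alice/my-project/run-42"
assert derive_prefix(None, None, "my-project", "run-42") == "my-project/run-42"
```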
-     def _download_metadata(self, exp_info: ExperimentInfo, result: DownloadResult):
-         """Download and create experiment metadata."""
-         # Create experiment directory structure with folder path
-         self.local.create_experiment(
-             project=exp_info.project,
-             name=exp_info.experiment,
-             description=exp_info.description,
-             tags=exp_info.tags,
-             bindrs=[],
-             folder=exp_info.folder,
-             metadata=None,
+     def _download_parameters(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download parameters."""
+         try:
+             params_data = self.remote.get_parameters(exp_info.experiment_id)
+             if params_data:
+                 self.local.write_parameters(
+                     project=exp_info.project, experiment=exp_info.experiment, data=params_data
                  )
+                 result.downloaded["parameters"] = 1
+                 result.bytes_downloaded += len(json.dumps(params_data))
+         except Exception as e:
+             result.failed.setdefault("parameters", []).append(str(e))

-     def _download_parameters(self, exp_info: ExperimentInfo, result: DownloadResult):
-         """Download parameters."""
-         try:
-             params_data = self.remote.get_parameters(exp_info.experiment_id)
-             if params_data:
-                 self.local.write_parameters(
-                     project=exp_info.project,
-                     experiment=exp_info.experiment,
-                     data=params_data
-                 )
-                 result.downloaded["parameters"] = 1
-                 result.bytes_downloaded += len(json.dumps(params_data))
-         except Exception as e:
-             result.failed.setdefault("parameters", []).append(str(e))
-
-     def _download_logs(self, exp_info: ExperimentInfo, result: DownloadResult):
-         """Download logs with pagination."""
-         try:
-             offset = 0
-             total_downloaded = 0
-
-             while True:
-                 logs_data = self.remote.query_logs(
-                     experiment_id=exp_info.experiment_id,
-                     limit=self.batch_size,
-                     offset=offset,
-                     order_by="sequenceNumber",
-                     order="asc"
-                 )
-
-                 logs = logs_data.get("logs", [])
-                 if not logs:
-                     break
-
-                 # Write logs
-                 for log in logs:
-                     self.local.write_log(
-                         project=exp_info.project,
-                         experiment=exp_info.experiment,
-                         message=log['message'],
-                         level=log['level'],
-                         timestamp=log['timestamp'],
-                         metadata=log.get('metadata')
-                     )
-
-                 total_downloaded += len(logs)
-                 result.bytes_downloaded += sum(len(json.dumps(log)) for log in logs)
-
-                 if not logs_data.get("hasMore", False):
-                     break
-
-                 offset += len(logs)
-
-             result.downloaded["logs"] = total_downloaded
-
-         except Exception as e:
-             result.failed.setdefault("logs", []).append(str(e))
-
-     def _download_metrics(self, exp_info: ExperimentInfo, result: DownloadResult):
-         """Download all metrics in parallel."""
-         with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
-             future_to_metric = {}
-
-             for metric_name in exp_info.metric_names:
-                 future = executor.submit(
-                     self._download_single_metric,
-                     exp_info.experiment_id,
-                     exp_info.project,
-                     exp_info.experiment,
-                     metric_name
-                 )
-                 future_to_metric[future] = metric_name
-
-             for future in as_completed(future_to_metric):
-                 metric_name = future_to_metric[future]
-                 metric_result = future.result()
-
-                 with self._lock:
-                     if metric_result['success']:
-                         result.downloaded["metrics"] = result.downloaded.get("metrics", 0) + 1
-                         result.bytes_downloaded += metric_result['bytes']
-                     else:
-                         result.failed.setdefault("metrics", []).append(
-                             f"{metric_name}: {metric_result['error']}"
-                         )
-
-     def _download_single_chunk(self, experiment_id: str, metric_name: str, chunk_number: int):
-         """Download a single chunk (for parallel downloading)."""
-         remote = self._get_remote_client()
-         try:
-             chunk_data = remote.download_metric_chunk(experiment_id, metric_name, chunk_number)
-             return {
-                 'success': True,
-                 'chunk_number': chunk_number,
-                 'data': chunk_data.get('data', []),
-                 'start_index': int(chunk_data.get('startIndex', 0)),
-                 'end_index': int(chunk_data.get('endIndex', 0)),
-                 'error': None
-             }
-         except Exception as e:
-             return {
-                 'success': False,
-                 'chunk_number': chunk_number,
-                 'error': str(e),
-                 'data': []
-             }
-
-     def _download_single_metric(self, experiment_id: str, project: str, experiment: str, metric_name: str):
-         """Download a single metric using chunk-aware approach (thread-safe)."""
-         remote = self._get_remote_client()
-
-         total_downloaded = 0
-         bytes_downloaded = 0
-
-         try:
-             # Get metric metadata to determine download strategy
-             metadata = remote.get_metric_stats(experiment_id, metric_name)
-             total_chunks = metadata.get('totalChunks', 0)
-             buffered_points = int(metadata.get('bufferedDataPoints', 0))
-
-             all_data = []
-
-             # Download chunks in parallel if they exist
-             if total_chunks > 0:
-                 from concurrent.futures import ThreadPoolExecutor, as_completed
-
-                 # Download all chunks in parallel (max 10 workers)
-                 with ThreadPoolExecutor(max_workers=min(10, total_chunks)) as executor:
-                     chunk_futures = {
-                         executor.submit(self._download_single_chunk, experiment_id, metric_name, i): i
-                         for i in range(total_chunks)
-                     }
-
-                     for future in as_completed(chunk_futures):
-                         result = future.result()
-                         if result['success']:
-                             all_data.extend(result['data'])
-                         else:
-                             # If a chunk fails, fall back to pagination
-                             raise Exception(f"Chunk {result['chunk_number']} download failed: {result['error']}")
-
-             # Download buffer data if exists
-             if buffered_points > 0:
-                 response = remote.get_metric_data(
-                     experiment_id, metric_name,
-                     buffer_only=True
-                 )
-                 buffer_data = response.get('data', [])
-                 all_data.extend(buffer_data)
-
-             # Sort all data by index
-             all_data.sort(key=lambda x: int(x.get('index', 0)))
-
-             # Write to local storage in batches
-             batch_size = 10000
-             for i in range(0, len(all_data), batch_size):
-                 batch = all_data[i:i + batch_size]
-                 self.local.append_batch_to_metric(
-                     project, experiment, metric_name,
-                     data_points=[d['data'] for d in batch]
-                 )
-                 total_downloaded += len(batch)
-                 bytes_downloaded += sum(len(json.dumps(d)) for d in batch)
-
-             return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
-
-         except Exception as e:
-             # Fall back to pagination if chunk download fails
-             console.print(f"[yellow]Chunk download failed for {metric_name}, falling back to pagination: {e}[/yellow]")
-             return self._download_metric_with_pagination(experiment_id, project, experiment, metric_name)
-
-     def _download_metric_with_pagination(self, experiment_id: str, project: str, experiment: str, metric_name: str):
-         """Original pagination-based download (fallback method)."""
-         remote = self._get_remote_client()
-
-         total_downloaded = 0
-         bytes_downloaded = 0
-         start_index = 0
-
-         try:
-             while True:
-                 response = remote.get_metric_data(
-                     experiment_id, metric_name,
-                     start_index=start_index,
-                     limit=self.batch_size
-                 )
-
-                 data_points = response.get('data', [])
-                 if not data_points:
-                     break
-
-                 # Write to local storage
-                 self.local.append_batch_to_metric(
-                     project, experiment, metric_name,
-                     data_points=[d['data'] for d in data_points]
-                 )
-
-                 total_downloaded += len(data_points)
-                 bytes_downloaded += sum(len(json.dumps(d)) for d in data_points)
-
-                 if not response.get('hasMore', False):
-                     break
-
-                 start_index += len(data_points)
-
-             return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
-
-         except Exception as e:
-             return {'success': False, 'error': str(e), 'downloaded': 0, 'bytes': 0}
-
-     def _download_files(self, exp_info: ExperimentInfo, result: DownloadResult):
-         """Download files in parallel."""
-         # Get file list
-         try:
-             files_data = self.remote.list_files(exp_info.experiment_id)
-         except Exception as e:
-             result.failed.setdefault("files", []).append(f"List files failed: {e}")
-             return
-
-         if not files_data:
-             return
-
-         with ThreadPoolExecutor(max_workers=self.max_concurrent_files) as executor:
-             future_to_file = {}
-
-             for file_info in files_data:
-                 future = executor.submit(
-                     self._download_single_file,
-                     exp_info.experiment_id,
-                     exp_info.project,
-                     exp_info.experiment,
-                     file_info
-                 )
-                 future_to_file[future] = file_info['filename']
-
-             for future in as_completed(future_to_file):
-                 filename = future_to_file[future]
-                 file_result = future.result()
-
-                 with self._lock:
-                     if file_result['success']:
-                         result.downloaded["files"] = result.downloaded.get("files", 0) + 1
-                         result.bytes_downloaded += file_result['bytes']
-                     else:
-                         result.failed.setdefault("files", []).append(
-                             f"{filename}: {file_result['error']}"
-                         )
-
-     def _download_single_file(self, experiment_id: str, project: str, experiment: str, file_info: Dict[str, Any]):
-         """Download a single file with streaming (thread-safe)."""
-         remote = self._get_remote_client()
-
-         try:
-             # Stream download to temp file
-             temp_fd, temp_path = tempfile.mkstemp(prefix="ml_dash_download_")
-             os.close(temp_fd)
-
-             remote.download_file_streaming(
-                 experiment_id, file_info['id'], dest_path=temp_path
-             )
+     def _download_logs(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download logs with pagination."""
+         try:
+             offset = 0
+             total_downloaded = 0
+
+             while True:
+                 logs_data = self.remote.query_logs(
+                     experiment_id=exp_info.experiment_id,
+                     limit=self.batch_size,
+                     offset=offset,
+                     order_by="sequenceNumber",
+                     order="asc",
+                 )

-             # Write to local storage
-             self.local.write_file(
-                 project=project,
-                 experiment=experiment,
-                 file_path=temp_path,
-                 prefix=file_info['path'],
-                 filename=file_info['filename'],
-                 description=file_info.get('description'),
-                 tags=file_info.get('tags', []),
-                 metadata=file_info.get('metadata'),
-                 checksum=file_info['checksum'],
-                 content_type=file_info['contentType'],
-                 size_bytes=file_info['sizeBytes']
-             )
+                 logs = logs_data.get("logs", [])
+                 if not logs:
+                     break

-             # Clean up temp file
-             os.remove(temp_path)
+                 # Write logs
+                 for log in logs:
+                     self.local.write_log(
+                         project=exp_info.project,
+                         experiment=exp_info.experiment,
+                         message=log["message"],
+                         level=log["level"],
+                         timestamp=log["timestamp"],
+                         metadata=log.get("metadata"),
+                     )

-             return {'success': True, 'bytes': file_info['sizeBytes'], 'error': None}
+                 total_downloaded += len(logs)
+                 result.bytes_downloaded += sum(len(json.dumps(log)) for log in logs)

-         except Exception as e:
-             return {'success': False, 'error': str(e), 'bytes': 0}
+                 if not logs_data.get("hasMore", False):
+                     break

+                 offset += len(logs)

- # ============================================================================
- # Main Command
- # ============================================================================
+             result.downloaded["logs"] = total_downloaded

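The log download is classic offset pagination: keep asking for `batch_size` rows until the server stops reporting `hasMore`. A stripped-down version of that loop, run against a fake paginated source:

```python
def fetch_all(query, batch_size=1000):
    """Same shape as the loop above: the offset advances by len(page)."""
    out, offset = [], 0
    while True:
        page = query(limit=batch_size, offset=offset)
        if not page["logs"]:
            break
        out.extend(page["logs"])
        if not page["hasMore"]:
            break
        offset += len(page["logs"])
    return out

data = list(range(2500))
fake = lambda limit, offset: {
    "logs": data[offset : offset + limit],
    "hasMore": offset + limit < len(data),
}
assert fetch_all(fake) == data
```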
- def cmd_download(args: argparse.Namespace) -> int:
-     """Execute download command."""
-     # Load configuration
-     config = Config()
-     remote_url = args.remote or config.remote_url
-     api_key = args.api_key or config.api_key # RemoteClient will auto-load if None
-
-     # Validate inputs
-     if not remote_url:
-         console.print("[red]Error:[/red] --remote is required (or set in config)")
-         return 1
-
-     # Initialize clients (RemoteClient will auto-load token if api_key is None)
-     remote_client = RemoteClient(base_url=remote_url, api_key=api_key)
-     local_storage = LocalStorage(root_path=Path(args.path))
-
-     # Load or create state
-     state_file = Path(args.state_file)
-     if args.resume:
-         state = DownloadState.load(state_file)
-         if state:
-             console.print(f"[cyan]Resuming from previous download ({len(state.completed_experiments)} completed)[/cyan]")
-         else:
-             console.print("[yellow]No previous state found, starting fresh[/yellow]")
-             state = DownloadState(
-                 remote_url=remote_url,
-                 local_path=str(args.path)
-             )
-     else:
-         state = DownloadState(
-             remote_url=remote_url,
-             local_path=str(args.path)
+         except Exception as e:
+             result.failed.setdefault("logs", []).append(str(e))
+
+     def _download_metrics(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download all metrics in parallel."""
+         with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
+             future_to_metric = {}
+
+             for metric_name in exp_info.metric_names:
+                 future = executor.submit(
+                     self._download_single_metric,
+                     exp_info.experiment_id,
+                     exp_info.project,
+                     exp_info.experiment,
+                     metric_name,
                  )
+                 future_to_metric[future] = metric_name
+
+             for future in as_completed(future_to_metric):
+                 metric_name = future_to_metric[future]
+                 metric_result = future.result()
+
+                 with self._lock:
+                     if metric_result["success"]:
+                         result.downloaded["metrics"] = result.downloaded.get("metrics", 0) + 1
+                         result.bytes_downloaded += metric_result["bytes"]
+                     else:
+                         result.failed.setdefault("metrics", []).append(
+                             f"{metric_name}: {metric_result['error']}"
+                         )

-     # Discover experiments
-     console.print("[bold]Discovering experiments on remote server...[/bold]")
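`_download_metrics` (and `_download_files` below) share one fan-out shape: submit one task per item, map each future back to its name, and merge results under a lock. Reduced to its skeleton, with a stand-in for `_download_single_metric`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

lock = threading.Lock()
downloaded, failed = {}, {}

def work(name):  # stand-in for _download_single_metric
    return {"success": name != "bad", "bytes": 10, "error": "boom"}

with ThreadPoolExecutor(max_workers=5) as pool:
    futures = {pool.submit(work, n): n for n in ["loss", "acc", "bad"]}
    for fut in as_completed(futures):
        name, res = futures[fut], fut.result()
        with lock:  # the result dicts are shared across workers
            if res["success"]:
                downloaded[name] = res["bytes"]
            else:
                failed[name] = res["error"]

assert set(downloaded) == {"loss", "acc"} and set(failed) == {"bad"}
```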
+     def _download_single_chunk(
+         self, experiment_id: str, metric_name: str, chunk_number: int
+     ):
+         """Download a single chunk (for parallel downloading)."""
+         remote = self._get_remote_client()
          try:
-         experiments = discover_experiments(
-             remote_client, args.project, args.experiment
+             chunk_data = remote.download_metric_chunk(
+                 experiment_id, metric_name, chunk_number
+             )
+             return {
+                 "success": True,
+                 "chunk_number": chunk_number,
+                 "data": chunk_data.get("data", []),
+                 "start_index": int(chunk_data.get("startIndex", 0)),
+                 "end_index": int(chunk_data.get("endIndex", 0)),
+                 "error": None,
+             }
+         except Exception as e:
+             return {
+                 "success": False,
+                 "chunk_number": chunk_number,
+                 "error": str(e),
+                 "data": [],
+             }
+
+     def _download_single_metric(
+         self, experiment_id: str, project: str, experiment: str, metric_name: str
+     ):
+         """Download a single metric using chunk-aware approach (thread-safe)."""
+         remote = self._get_remote_client()
+
+         total_downloaded = 0
+         bytes_downloaded = 0
+
+         try:
+             # Get metric metadata to determine download strategy
+             metadata = remote.get_metric_stats(experiment_id, metric_name)
+             total_chunks = metadata.get("totalChunks", 0)
+             buffered_points = int(metadata.get("bufferedDataPoints", 0))
+
+             all_data = []
+
+             # Download chunks in parallel if they exist
+             if total_chunks > 0:
+                 from concurrent.futures import ThreadPoolExecutor, as_completed
+
+                 # Download all chunks in parallel (max 10 workers)
+                 with ThreadPoolExecutor(max_workers=min(10, total_chunks)) as executor:
+                     chunk_futures = {
+                         executor.submit(
+                             self._download_single_chunk, experiment_id, metric_name, i
+                         ): i
+                         for i in range(total_chunks)
+                     }
+
+                     for future in as_completed(chunk_futures):
+                         result = future.result()
+                         if result["success"]:
+                             all_data.extend(result["data"])
+                         else:
+                             # If a chunk fails, fall back to pagination
+                             raise Exception(
+                                 f"Chunk {result['chunk_number']} download failed: {result['error']}"
+                             )
+
+             # Download buffer data if exists
+             if buffered_points > 0:
+                 response = remote.get_metric_data(experiment_id, metric_name, buffer_only=True)
+                 buffer_data = response.get("data", [])
+                 all_data.extend(buffer_data)
+
+             # Sort all data by index
+             all_data.sort(key=lambda x: int(x.get("index", 0)))
+
+             # Write to local storage in batches
+             batch_size = 10000
+             for i in range(0, len(all_data), batch_size):
+                 batch = all_data[i : i + batch_size]
+                 self.local.append_batch_to_metric(
+                     project, experiment, metric_name, data_points=[d["data"] for d in batch]
                  )
+                 total_downloaded += len(batch)
+                 bytes_downloaded += sum(len(json.dumps(d)) for d in batch)
+
+             return {
+                 "success": True,
+                 "downloaded": total_downloaded,
+                 "bytes": bytes_downloaded,
+                 "error": None,
+             }
+
          except Exception as e:
-         console.print(f"[red]Failed to discover experiments: {e}[/red]")
-         return 1
-
-     if not experiments:
-         console.print("[yellow]No experiments found[/yellow]")
-         return 0
-
-     console.print(f"Found {len(experiments)} experiment(s)")
-
-     # Filter out completed experiments
-     experiments_to_download = []
-     for exp in experiments:
-         exp_key = f"{exp.project}/{exp.experiment}"
-
-         # Skip if already completed
-         if exp_key in state.completed_experiments and not args.overwrite:
-             console.print(f" [dim]Skipping {exp_key} (already completed)[/dim]")
-             continue
-
-         # Check if exists locally
-         exp_json = local_storage.root_path / exp.project / exp.experiment / "experiment.json"
-         if exp_json.exists() and not args.overwrite:
-             console.print(f" [yellow]Skipping {exp_key} (already exists locally)[/yellow]")
-             continue
-
-         experiments_to_download.append(exp)
-
-     if not experiments_to_download:
-         console.print("[green]All experiments already downloaded[/green]")
-         return 0
-
-     # Dry run mode
-     if args.dry_run:
-         console.print("\n[bold]Dry run - would download:[/bold]")
-         for exp in experiments_to_download:
-             console.print(f" • {exp.project}/{exp.experiment}")
-             console.print(f" Logs: {exp.log_count}, Metrics: {len(exp.metric_names)}, Files: {exp.file_count}")
-         return 0
-
-     # Download experiments
-     console.print(f"\n[bold]Downloading {len(experiments_to_download)} experiment(s)...[/bold]")
-     results = []
-     start_time = time.time()
-
-     for i, exp in enumerate(experiments_to_download, 1):
-         exp_key = f"{exp.project}/{exp.experiment}"
-         console.print(f"\n[cyan][{i}/{len(experiments_to_download)}] {exp_key}[/cyan]")
-
-         # Mark as in-progress
-         state.in_progress_experiment = exp_key
-         state.save(state_file)
-
-         # Download
-         downloader = ExperimentDownloader(
-             local_storage=local_storage,
-             remote_client=remote_client,
-             batch_size=args.batch_size,
-             skip_logs=args.skip_logs,
-             skip_metrics=args.skip_metrics,
-             skip_files=args.skip_files,
-             skip_params=args.skip_params,
-             verbose=args.verbose,
-             max_concurrent_metrics=args.max_concurrent_metrics,
-             max_concurrent_files=args.max_concurrent_files,
+             # Fall back to pagination if chunk download fails
+             console.print(
+                 f"[yellow]Chunk download failed for {metric_name}, falling back to pagination: {e}[/yellow]"
+             )
+             return self._download_metric_with_pagination(
+                 experiment_id, project, experiment, metric_name
+             )
+
+     def _download_metric_with_pagination(
+         self, experiment_id: str, project: str, experiment: str, metric_name: str
+     ):
+         """Original pagination-based download (fallback method)."""
+         remote = self._get_remote_client()
+
+         total_downloaded = 0
+         bytes_downloaded = 0
+         start_index = 0
+
+         try:
+             while True:
+                 response = remote.get_metric_data(
+                     experiment_id, metric_name, start_index=start_index, limit=self.batch_size
                  )

-         result = downloader.download_experiment(exp)
-         results.append(result)
+                 data_points = response.get("data", [])
+                 if not data_points:
+                     break

-         # Update state
-         if result.success:
-             state.completed_experiments.append(exp_key)
-             console.print(f" [green]✓ Downloaded successfully[/green]")
-         else:
-             state.failed_experiments.append(exp_key)
-             console.print(f" [red]✗ Failed: {', '.join(result.errors)}[/red]")
+                 # Write to local storage
+                 self.local.append_batch_to_metric(
+                     project, experiment, metric_name, data_points=[d["data"] for d in data_points]
+                 )

-         state.in_progress_experiment = None
-         state.save(state_file)
+                 total_downloaded += len(data_points)
+                 bytes_downloaded += sum(len(json.dumps(d)) for d in data_points)

-     # Show summary
-     end_time = time.time()
-     elapsed_time = end_time - start_time
-     total_bytes = sum(r.bytes_downloaded for r in results)
-     successful = sum(1 for r in results if r.success)
+                 if not response.get("hasMore", False):
+                     break

-     console.print("\n[bold]Download Summary[/bold]")
-     summary_table = Table()
-     summary_table.add_column("Metric", style="cyan")
-     summary_table.add_column("Value", style="green")
+                 start_index += len(data_points)

-     summary_table.add_row("Total Experiments", str(len(results)))
-     summary_table.add_row("Successful", str(successful))
-     summary_table.add_row("Failed", str(len(results) - successful))
-     summary_table.add_row("Total Data", _format_bytes(total_bytes))
-     summary_table.add_row("Total Time", f"{elapsed_time:.2f}s")
+             return {
+                 "success": True,
+                 "downloaded": total_downloaded,
+                 "bytes": bytes_downloaded,
+                 "error": None,
+             }

-     if elapsed_time > 0:
-         speed = total_bytes / elapsed_time
-         summary_table.add_row("Avg Speed", _format_bytes_per_sec(speed))
+         except Exception as e:
+             return {"success": False, "error": str(e), "downloaded": 0, "bytes": 0}

-     console.print(summary_table)
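Because chunks complete in arbitrary order, the method above collects everything first and then sorts by the per-point `index` before writing. The invariant, in miniature:

```python
# Chunks may finish as 2, 0, 1 -- the final sort restores global order.
chunks = {
    2: [{"index": 4, "data": "e"}, {"index": 5, "data": "f"}],
    0: [{"index": 0, "data": "a"}, {"index": 1, "data": "b"}],
    1: [{"index": 2, "data": "c"}, {"index": 3, "data": "d"}],
}
all_data = []
for completed_chunk in [2, 0, 1]:  # arbitrary completion order
    all_data.extend(chunks[completed_chunk])
all_data.sort(key=lambda x: int(x.get("index", 0)))
assert [d["data"] for d in all_data] == ["a", "b", "c", "d", "e", "f"]
```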
+     def _download_files(self, exp_info: ExperimentInfo, result: DownloadResult):
+         """Download files in parallel."""
+         # Get file list
+         try:
+             files_data = self.remote.list_files(exp_info.experiment_id)
+         except Exception as e:
+             result.failed.setdefault("files", []).append(f"List files failed: {e}")
+             return
+
+         if not files_data:
+             return
+
+         with ThreadPoolExecutor(max_workers=self.max_concurrent_files) as executor:
+             future_to_file = {}
+
+             for file_info in files_data:
+                 future = executor.submit(
+                     self._download_single_file,
+                     exp_info.experiment_id,
+                     exp_info.project,
+                     exp_info.experiment,
+                     file_info,
+                 )
+                 future_to_file[future] = file_info["filename"]
+
+             for future in as_completed(future_to_file):
+                 filename = future_to_file[future]
+                 file_result = future.result()
+
+                 with self._lock:
+                     if file_result["success"]:
+                         result.downloaded["files"] = result.downloaded.get("files", 0) + 1
+                         result.bytes_downloaded += file_result["bytes"]
+                     else:
+                         result.failed.setdefault("files", []).append(
+                             f"{filename}: {file_result['error']}"
+                         )

-     # Clean up state file if all successful
-     if all(r.success for r in results):
-         state_file.unlink(missing_ok=True)
+     def _download_single_file(
+         self, experiment_id: str, project: str, experiment: str, file_info: Dict[str, Any]
+     ):
+         """Download a single file with streaming (thread-safe)."""
+         remote = self._get_remote_client()

-     return 0 if all(r.success for r in results) else 1
+         try:
+             # Stream download to temp file
+             temp_fd, temp_path = tempfile.mkstemp(prefix="ml_dash_download_")
+             os.close(temp_fd)
+
+             remote.download_file_streaming(
+                 experiment_id, file_info["id"], dest_path=temp_path
+             )
+
+             # Write to local storage
+             self.local.write_file(
+                 project=project,
+                 experiment=experiment,
+                 file_path=temp_path,
+                 prefix=file_info["path"],
+                 filename=file_info["filename"],
+                 description=file_info.get("description"),
+                 tags=file_info.get("tags", []),
+                 metadata=file_info.get("metadata"),
+                 checksum=file_info["checksum"],
+                 content_type=file_info["contentType"],
+                 size_bytes=file_info["sizeBytes"],
+             )
+
+             # Clean up temp file
+             os.remove(temp_path)
+
+             return {"success": True, "bytes": file_info["sizeBytes"], "error": None}

+         except Exception as e:
+             return {"success": False, "error": str(e), "bytes": 0}

- def add_parser(subparsers):
-     """Add download command parser."""
-     parser = subparsers.add_parser(
-         "download",
-         help="Download experiments from remote server to local storage"
-     )

-     # Positional arguments
-     parser.add_argument(
-         "path",
-         nargs="?",
-         default="./.ml-dash",
-         help="Local storage directory (default: ./.ml-dash)"
-     )
+ # ============================================================================
+ # Main Command
+ # ============================================================================

-     # Remote configuration
-     parser.add_argument("--remote", help="Remote server URL")
-     parser.add_argument("--api-key", help="JWT authentication token (optional - auto-loads from 'ml-dash login')")
-
-     # Scope control
-     parser.add_argument("--project", help="Download only this project")
-     parser.add_argument("--experiment", help="Download specific experiment (requires --project)")
-
-     # Data filtering
-     parser.add_argument("--skip-logs", action="store_true", help="Don't download logs")
-     parser.add_argument("--skip-metrics", action="store_true", help="Don't download metrics")
-     parser.add_argument("--skip-files", action="store_true", help="Don't download files")
-     parser.add_argument("--skip-params", action="store_true", help="Don't download parameters")
-
-     # Behavior control
-     parser.add_argument("--dry-run", action="store_true", help="Preview without downloading")
-     parser.add_argument("--overwrite", action="store_true", help="Overwrite existing experiments")
-     parser.add_argument("--resume", action="store_true", help="Resume interrupted download")
-     parser.add_argument(
-         "--state-file",
-         default=".ml-dash-download-state.json",
-         help="State file path for resume (default: .ml-dash-download-state.json)"
-     )
-     parser.add_argument(
-         "--batch-size",
-         type=int,
-         default=1000,
-         help="Batch size for logs/metrics (default: 1000, max: 10000)"
+
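`_download_single_file` streams into a `mkstemp` file first and only hands the payload to `LocalStorage.write_file` once the transfer finishes, so a partial download never lands in the store. The same pattern in isolation, with a fake fetch in place of `download_file_streaming` and a simple move standing in for `write_file`:

```python
import os, shutil, tempfile

def fetch_to(path):  # stand-in for remote.download_file_streaming
    with open(path, "wb") as f:
        f.write(b"model bytes")

fd, tmp = tempfile.mkstemp(prefix="ml_dash_download_")
os.close(fd)  # mkstemp returns an open fd; close it before reuse
try:
    fetch_to(tmp)                 # stream into the temp file
    shutil.move(tmp, "model.pt")  # publish only when complete
finally:
    if os.path.exists(tmp):
        os.remove(tmp)            # clean up on failure paths
```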
+ def cmd_download(args: argparse.Namespace) -> int:
+     """Execute download command."""
+     # Load configuration
+     config = Config()
+     remote_url = args.dash_url or config.remote_url
+     api_key = args.api_key or config.api_key # RemoteClient will auto-load if None
+
+     # Validate inputs
+     if not remote_url:
+         console.print("[red]Error:[/red] --dash-url is required (or set in config)")
+         return 1
+
+     # Extract namespace from project argument
+     namespace = None
+     if args.project:
+         # Parse namespace from project filter (format: "owner/project" or "owner/project/exp")
+         project_parts = args.project.strip("/").split("/")
+         if len(project_parts) >= 2: # Has at least "owner/project"
+             namespace = project_parts[0]
+
+     if not namespace:
+         console.print(
+             "[red]Error:[/red] --project must be in format 'namespace/project' or 'namespace/project/exp'"
          )

-     parser.add_argument(
-         "--max-concurrent-metrics",
-         type=int,
-         default=5,
-         help="Parallel metric downloads (default: 5)"
+         console.print("Example: ml-dash download --project alice/my-project")
+         return 1
+
+     # Initialize clients (RemoteClient will auto-load token if api_key is None)
+     remote_client = RemoteClient(base_url=remote_url, namespace=namespace, api_key=api_key)
+     local_storage = LocalStorage(root_path=Path(args.path))
+
+     # Load or create state
+     state_file = Path(args.state_file)
+     if args.resume:
+         state = DownloadState.load(state_file)
+         if state:
+             console.print(
+                 f"[cyan]Resuming from previous download ({len(state.completed_experiments)} completed)[/cyan]"
+             )
+         else:
+             console.print("[yellow]No previous state found, starting fresh[/yellow]")
+             state = DownloadState(remote_url=remote_url, local_path=str(args.path))
+     else:
+         state = DownloadState(remote_url=remote_url, local_path=str(args.path))
+
+     # Discover experiments
+     console.print("[bold]Discovering experiments on remote server...[/bold]")
+     try:
+         experiments = discover_experiments(remote_client, args.project, args.experiment)
+     except Exception as e:
+         console.print(f"[red]Failed to discover experiments: {e}[/red]")
+         return 1
+
+     if not experiments:
+         console.print("[yellow]No experiments found[/yellow]")
+         return 0
+
+     console.print(f"Found {len(experiments)} experiment(s)")
+
+     # Filter out completed experiments
+     experiments_to_download = []
+     for exp in experiments:
+         exp_key = f"{exp.project}/{exp.experiment}"
+
+         # Skip if already completed
+         if exp_key in state.completed_experiments and not args.overwrite:
+             console.print(f" [dim]Skipping {exp_key} (already completed)[/dim]")
+             continue
+
+         # Check if exists locally
+         exp_json = (
+             local_storage.root_path / exp.project / exp.experiment / "experiment.json"
          )

-     parser.add_argument(
-         "--max-concurrent-files",
-         type=int,
-         default=3,
-         help="Parallel file downloads (default: 3)"
+         if exp_json.exists() and not args.overwrite:
+             console.print(f" [yellow]Skipping {exp_key} (already exists locally)[/yellow]")
+             continue
+
+         experiments_to_download.append(exp)
+
+     if not experiments_to_download:
+         console.print("[green]All experiments already downloaded[/green]")
+         return 0
+
+     # Dry run mode
+     if args.dry_run:
+         console.print("\n[bold]Dry run - would download:[/bold]")
+         for exp in experiments_to_download:
+             console.print(f" • {exp.project}/{exp.experiment}")
+             console.print(
+                 f" Logs: {exp.log_count}, Metrics: {len(exp.metric_names)}, Files: {exp.file_count}"
+             )
+         return 0
+
+     # Download experiments
+     console.print(
+         f"\n[bold]Downloading {len(experiments_to_download)} experiment(s)...[/bold]"
+     )
+     results = []
+     start_time = time.time()
+
+     for i, exp in enumerate(experiments_to_download, 1):
+         exp_key = f"{exp.project}/{exp.experiment}"
+         console.print(f"\n[cyan][{i}/{len(experiments_to_download)}] {exp_key}[/cyan]")
+
+         # Mark as in-progress
+         state.in_progress_experiment = exp_key
+         state.save(state_file)
+
+         # Download
+         downloader = ExperimentDownloader(
+             local_storage=local_storage,
+             remote_client=remote_client,
+             batch_size=args.batch_size,
+             skip_logs=args.skip_logs,
+             skip_metrics=args.skip_metrics,
+             skip_files=args.skip_files,
+             skip_params=args.skip_params,
+             verbose=args.verbose,
+             max_concurrent_metrics=args.max_concurrent_metrics,
+             max_concurrent_files=args.max_concurrent_files,
          )

-     parser.add_argument("-v", "--verbose", action="store_true", help="Detailed progress output")

-     parser.set_defaults(func=cmd_download)
+         result = downloader.download_experiment(exp)
+         results.append(result)
+
+         # Update state
+         if result.success:
+             state.completed_experiments.append(exp_key)
+             console.print(" [green]✓ Downloaded successfully[/green]")
+         else:
+             state.failed_experiments.append(exp_key)
+             console.print(f" [red]✗ Failed: {', '.join(result.errors)}[/red]")
+
+         state.in_progress_experiment = None
+         state.save(state_file)
+
+     # Show summary
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     total_bytes = sum(r.bytes_downloaded for r in results)
+     successful = sum(1 for r in results if r.success)
+
+     console.print("\n[bold]Download Summary[/bold]")
+     summary_table = Table()
+     summary_table.add_column("Metric", style="cyan")
+     summary_table.add_column("Value", style="green")
+
+     summary_table.add_row("Total Experiments", str(len(results)))
+     summary_table.add_row("Successful", str(successful))
+     summary_table.add_row("Failed", str(len(results) - successful))
+     summary_table.add_row("Total Data", _format_bytes(total_bytes))
+     summary_table.add_row("Total Time", f"{elapsed_time:.2f}s")
+
+     if elapsed_time > 0:
+         speed = total_bytes / elapsed_time
+         summary_table.add_row("Avg Speed", _format_bytes_per_sec(speed))
+
+     console.print(summary_table)
+
+     # Clean up state file if all successful
+     if all(r.success for r in results):
+         state_file.unlink(missing_ok=True)
+
+     return 0 if all(r.success for r in results) else 1
+
+
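The new `--project` handling requires a `namespace/project` form and peels the namespace off the front. Restated standalone (the command itself only keeps the namespace; the `rest` value here is for illustration):

```python
def split_namespace(project_arg):
    """Returns (namespace, rest), or (None, ...) when no namespace is given."""
    parts = project_arg.strip("/").split("/")
    if len(parts) >= 2:  # at least "owner/project"
        return parts[0], "/".join(parts[1:])
    return None, project_arg

assert split_namespace("alice/my-project") == ("alice", "my-project")
assert split_namespace("alice/my-project/exp-1") == ("alice", "my-project/exp-1")
assert split_namespace("my-project") == (None, "my-project")  # rejected by cmd_download
```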
+ def add_parser(subparsers):
+     """Add download command parser."""
+     parser = subparsers.add_parser(
+         "download", help="Download experiments from remote server to local storage"
+     )
+
+     # Positional arguments
+     parser.add_argument(
+         "path",
+         nargs="?",
+         default="./.dash",
+         help="Local storage directory (default: ./.dash)",
+     )
+
+     # Remote configuration
+     parser.add_argument(
+         "--dash-url",
+         help="ML-Dash server URL (defaults to config or https://api.dash.ml)",
+     )
+     parser.add_argument(
+         "--api-key",
+         help="JWT authentication token (optional - auto-loads from 'ml-dash login')",
+     )
+
+     # Scope control
+     parser.add_argument(
+         "-p",
+         "--pref",
+         "--prefix",
+         "--proj",
+         "--project",
+         dest="project",
+         type=str,
+         help="Filter experiments by project or pattern (supports glob: 'tut*', 'tom*/tutorials/*')"
+     )
+     parser.add_argument(
+         "--experiment",
+         help="Download specific experiment (requires --project)"
+     )
+
+     # Data filtering
+     parser.add_argument("--skip-logs", action="store_true", help="Don't download logs")
+     parser.add_argument(
+         "--skip-metrics", action="store_true", help="Don't download metrics"
+     )
+     parser.add_argument("--skip-files", action="store_true", help="Don't download files")
+     parser.add_argument(
+         "--skip-params", action="store_true", help="Don't download parameters"
+     )
+
+     # Behavior control
+     parser.add_argument(
+         "--dry-run", action="store_true", help="Preview without downloading"
+     )
+     parser.add_argument(
+         "--overwrite", action="store_true", help="Overwrite existing experiments"
+     )
+     parser.add_argument(
+         "--resume", action="store_true", help="Resume interrupted download"
+     )
+     parser.add_argument(
+         "--state-file",
+         default=".dash-download-state.json",
+         help="State file path for resume (default: .dash-download-state.json)",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=1000,
+         help="Batch size for logs/metrics (default: 1000, max: 10000)",
+     )
+     parser.add_argument(
+         "--max-concurrent-metrics",
+         type=int,
+         default=5,
+         help="Parallel metric downloads (default: 5)",
+     )
+     parser.add_argument(
+         "--max-concurrent-files",
+         type=int,
+         default=3,
+         help="Parallel file downloads (default: 3)",
+     )
+     parser.add_argument(
+         "-v", "--verbose", action="store_true", help="Detailed progress output"
+     )
+
+     parser.set_defaults(func=cmd_download)
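Putting the new flags together, a sketch of how the subcommand could be wired and invoked, assuming `add_parser` and `cmd_download` are importable from this module (note that `--dash-url` replaces 0.6.2rc1's `--remote`):

```python
import argparse

# Hypothetical wiring -- mirrors how a CLI entry point would register this command.
parser = argparse.ArgumentParser(prog="ml-dash")
add_parser(parser.add_subparsers())

args = parser.parse_args(
    ["download", "./.dash", "--project", "alice/my-project", "--dry-run", "-v"]
)
assert args.project == "alice/my-project" and args.dry_run
# args.func(args) would then run cmd_download with these options.
```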