ml-dash 0.6.2rc1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml_dash/__init__.py +36 -64
- ml_dash/auth/token_storage.py +267 -226
- ml_dash/auto_start.py +28 -15
- ml_dash/cli.py +16 -2
- ml_dash/cli_commands/api.py +165 -0
- ml_dash/cli_commands/download.py +757 -667
- ml_dash/cli_commands/list.py +146 -13
- ml_dash/cli_commands/login.py +190 -183
- ml_dash/cli_commands/profile.py +92 -0
- ml_dash/cli_commands/upload.py +1291 -1141
- ml_dash/client.py +79 -6
- ml_dash/config.py +119 -119
- ml_dash/experiment.py +1234 -1034
- ml_dash/files.py +339 -224
- ml_dash/log.py +7 -7
- ml_dash/metric.py +359 -100
- ml_dash/params.py +6 -6
- ml_dash/remote_auto_start.py +20 -17
- ml_dash/run.py +211 -65
- ml_dash/snowflake.py +173 -0
- ml_dash/storage.py +1051 -1081
- {ml_dash-0.6.2rc1.dist-info → ml_dash-0.6.3.dist-info}/METADATA +12 -14
- ml_dash-0.6.3.dist-info/RECORD +33 -0
- {ml_dash-0.6.2rc1.dist-info → ml_dash-0.6.3.dist-info}/WHEEL +1 -1
- ml_dash-0.6.2rc1.dist-info/RECORD +0 -30
- {ml_dash-0.6.2rc1.dist-info → ml_dash-0.6.3.dist-info}/entry_points.txt +0 -0
ml_dash/cli_commands/download.py
CHANGED
(In the hunks below, "[N deleted lines not rendered]" marks deletions whose content the source rendering did not preserve.)

@@ -9,18 +9,16 @@ import tempfile
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import
+from typing import Any, Dict, List, Optional

 from rich.console import Console
-from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TransferSpeedColumn
 from rich.table import Table
-from rich.panel import Panel

 from ..client import RemoteClient
-from ..storage import LocalStorage
 from ..config import Config
+from ..storage import LocalStorage

 console = Console()

@@ -29,741 +27,833 @@ console = Console()
 # Data Classes
 # ============================================================================

+
 @dataclass
 class ExperimentInfo:
- [14 deleted lines not rendered]
+    """Information about an experiment to download."""
+
+    project: str
+    experiment: str
+    experiment_id: str
+    owner: Optional[str] = None
+    has_logs: bool = False
+    has_params: bool = False
+    metric_names: List[str] = field(default_factory=list)
+    file_count: int = 0
+    estimated_size: int = 0
+    log_count: int = 0
+    status: str = "RUNNING"
+    prefix: Optional[str] = None
+    description: Optional[str] = None
+    tags: List[str] = field(default_factory=list)


 @dataclass
 class DownloadState:
- [33 deleted lines not rendered]
+    """State for resuming interrupted downloads."""
+
+    remote_url: str
+    local_path: str
+    completed_experiments: List[str] = field(default_factory=list)
+    failed_experiments: List[str] = field(default_factory=list)
+    in_progress_experiment: Optional[str] = None
+    in_progress_items: Dict[str, Any] = field(default_factory=dict)
+    timestamp: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DownloadState":
+        """Create from dictionary."""
+        return cls(**data)
+
+    def save(self, path: Path):
+        """Save state to JSON file."""
+        path.write_text(json.dumps(self.to_dict(), indent=2))
+
+    @classmethod
+    def load(cls, path: Path) -> Optional["DownloadState"]:
+        """Load state from JSON file."""
+        if not path.exists():
+            return None
+        try:
+            data = json.loads(path.read_text())
+            return cls.from_dict(data)
+        except Exception as e:
+            console.print(f"[yellow]Warning: Could not load state file: {e}[/yellow]")
+            return None


 @dataclass
 class DownloadResult:
- [8 deleted lines not rendered]
+    """Result of downloading an experiment."""
+
+    experiment: str
+    success: bool = False
+    downloaded: Dict[str, int] = field(default_factory=dict)
+    failed: Dict[str, List[str]] = field(default_factory=dict)
+    errors: List[str] = field(default_factory=list)
+    bytes_downloaded: int = 0
+    skipped: bool = False


 # ============================================================================
 # Helper Functions
 # ============================================================================

+
 def _format_bytes(bytes_count: int) -> str:
- [6 deleted lines not rendered]
+    """Format bytes as human-readable string."""
+    for unit in ["B", "KB", "MB", "GB"]:
+        if bytes_count < 1024:
+            return f"{bytes_count:.2f} {unit}"
+        bytes_count /= 1024
+    return f"{bytes_count:.2f} TB"


 def _format_bytes_per_sec(bytes_per_sec: float) -> str:
- [2 deleted lines not rendered]
+    """Format bytes per second as human-readable string."""
+    return f"{_format_bytes(bytes_per_sec)}/s"


 def _experiment_from_graphql(graphql_data: Dict[str, Any]) -> ExperimentInfo:
- [20 deleted lines not rendered]
-    )
+    """Convert GraphQL experiment data to ExperimentInfo."""
+    log_metadata = graphql_data.get("logMetadata") or {}
+
+    # Extract prefix from metadata if it exists
+    metadata = graphql_data.get("metadata") or {}
+    prefix = metadata.get("prefix") if isinstance(metadata, dict) else None
+
+    # Extract owner/namespace from project data
+    project_data = graphql_data.get("project", {})
+    namespace_data = project_data.get("namespace", {})
+    owner = namespace_data.get("slug") if namespace_data else None
+
+    return ExperimentInfo(
+        project=project_data.get("slug", "unknown"),
+        experiment=graphql_data["name"],
+        experiment_id=graphql_data["id"],
+        owner=owner,
+        has_logs=log_metadata.get("totalLogs", 0) > 0,
+        has_params=graphql_data.get("parameters") is not None,
+        metric_names=[m["name"] for m in graphql_data.get("metrics", []) or []],
+        file_count=len(graphql_data.get("files", []) or []),
+        log_count=int(log_metadata.get("totalLogs", 0)),
+        status=graphql_data.get("status", "RUNNING"),
+        prefix=prefix,
+        description=graphql_data.get("description"),
+        tags=graphql_data.get("tags", []) or [],
+    )


 def discover_experiments(
- [3 deleted lines not rendered]
+    remote_client: RemoteClient,
+    project_filter: Optional[str] = None,
+    experiment_filter: Optional[str] = None,
 ) -> List[ExperimentInfo]:
- [31 deleted lines not rendered]
+    """
+    Discover experiments on remote server using GraphQL.
+
+    Args:
+        remote_client: Remote API client
+        project_filter: Optional project slug or glob pattern
+        experiment_filter: Optional experiment name filter
+
+    Returns:
+        List of ExperimentInfo objects
+    """
+    # Specific experiment requested
+    if project_filter and experiment_filter:
+        exp_data = remote_client.get_experiment_graphql(project_filter, experiment_filter)
+        if exp_data:
+            return [_experiment_from_graphql(exp_data)]
+        return []
+
+    # Check if project_filter contains glob wildcards
+    has_wildcards = project_filter and any(c in project_filter for c in ['*', '?', '['])
+
+    if has_wildcards:
+        # Use searchExperiments for glob patterns
+        # Expand simple project patterns to full namespace/project/experiment format
+        if '/' not in project_filter:
+            # Simple project pattern: "tut*" -> "*/tut*/*"
+            search_pattern = f"*/{project_filter}/*"
+        else:
+            # Full or partial pattern: use as-is
+            search_pattern = project_filter
+
+        experiments_data = remote_client.search_experiments_graphql(search_pattern)
+        return [_experiment_from_graphql(exp) for exp in experiments_data]
+
+    # Project filter - get all experiments in project (exact match)
+    if project_filter:
+        experiments_data = remote_client.list_experiments_graphql(project_filter)
+        return [_experiment_from_graphql(exp) for exp in experiments_data]
+
+    # No filter - get all projects and their experiments
+    projects = remote_client.list_projects_graphql()
+    all_experiments = []
+    for project in projects:
+        experiments_data = remote_client.list_experiments_graphql(project["slug"])
+        all_experiments.extend([_experiment_from_graphql(exp) for exp in experiments_data])
+
+    return all_experiments


 # ============================================================================
 # Experiment Downloader
 # ============================================================================

-class ExperimentDownloader:
-    """Downloads a single experiment from remote server."""
-
-    def __init__(
-        self,
-        local_storage: LocalStorage,
-        remote_client: RemoteClient,
-        batch_size: int = 1000,
-        skip_logs: bool = False,
-        skip_metrics: bool = False,
-        skip_files: bool = False,
-        skip_params: bool = False,
-        verbose: bool = False,
-        max_concurrent_metrics: int = 5,
-        max_concurrent_files: int = 3,
-    ):
-        self.local = local_storage
-        self.remote = remote_client
-        self.batch_size = batch_size
-        self.skip_logs = skip_logs
-        self.skip_metrics = skip_metrics
-        self.skip_files = skip_files
-        self.skip_params = skip_params
-        self.verbose = verbose
-        self.max_concurrent_metrics = max_concurrent_metrics
-        self.max_concurrent_files = max_concurrent_files
-        self._lock = threading.Lock()
-        self._thread_local = threading.local()
-
-    def _get_remote_client(self) -> RemoteClient:
-        """Get thread-local remote client for safe concurrent access."""
-        if not hasattr(self._thread_local, 'client'):
-            self._thread_local.client = RemoteClient(
-                base_url=self.remote.base_url,
-                api_key=self.remote.api_key
-            )
-        return self._thread_local.client

- [2 deleted lines not rendered]
-        result = DownloadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")
+class ExperimentDownloader:
+    """Downloads a single experiment from remote server."""

- [3 deleted lines not rendered]
+    def __init__(
+        self,
+        local_storage: LocalStorage,
+        remote_client: RemoteClient,
+        batch_size: int = 1000,
+        skip_logs: bool = False,
+        skip_metrics: bool = False,
+        skip_files: bool = False,
+        skip_params: bool = False,
+        verbose: bool = False,
+        max_concurrent_metrics: int = 5,
+        max_concurrent_files: int = 3,
+    ):
+        self.local = local_storage
+        self.remote = remote_client
+        self.batch_size = batch_size
+        self.skip_logs = skip_logs
+        self.skip_metrics = skip_metrics
+        self.skip_files = skip_files
+        self.skip_params = skip_params
+        self.verbose = verbose
+        self.max_concurrent_metrics = max_concurrent_metrics
+        self.max_concurrent_files = max_concurrent_files
+        self._lock = threading.Lock()
+        self._thread_local = threading.local()
+
+    def _get_remote_client(self) -> RemoteClient:
+        """Get thread-local remote client for safe concurrent access."""
+        if not hasattr(self._thread_local, "client"):
+            self._thread_local.client = RemoteClient(
+                base_url=self.remote.base_url, api_key=self.remote.api_key
+            )
+        return self._thread_local.client
+
+    def download_experiment(self, exp_info: ExperimentInfo) -> DownloadResult:
+        """Download a complete experiment."""
+        result = DownloadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")

- [2 deleted lines not rendered]
+        try:
+            if self.verbose:
+                console.print(
+                    f"  [dim]Downloading {exp_info.project}/{exp_info.experiment}[/dim]"
+                )

- [2 deleted lines not rendered]
-            self._download_parameters(exp_info, result)
+            # Step 1: Download metadata and create experiment
+            self._download_metadata(exp_info, result)

- [3 deleted lines not rendered]
+            # Step 2: Download parameters
+            if not self.skip_params and exp_info.has_params:
+                self._download_parameters(exp_info, result)

- [3 deleted lines not rendered]
+            # Step 3: Download logs
+            if not self.skip_logs and exp_info.has_logs:
+                self._download_logs(exp_info, result)

- [3 deleted lines not rendered]
+            # Step 4: Download metrics (parallel)
+            if not self.skip_metrics and exp_info.metric_names:
+                self._download_metrics(exp_info, result)

- [1 deleted line not rendered]
+            # Step 5: Download files (parallel)
+            if not self.skip_files and exp_info.file_count > 0:
+                self._download_files(exp_info, result)

- [1 deleted line not rendered]
-            result.success = False
-            result.errors.append(str(e))
-            if self.verbose:
-                console.print(f"  [red]Error: {e}[/red]")
+            result.success = True

- [1 deleted line not rendered]
+        except Exception as e:
+            result.success = False
+            result.errors.append(str(e))
+            if self.verbose:
+                console.print(f"  [red]Error: {e}[/red]")
+
+        return result
+
+    def _download_metadata(self, exp_info: ExperimentInfo, result: DownloadResult):
+        """Download and create experiment metadata."""
+        # Create experiment directory structure with prefix
+        # The prefix from server already includes the full path (owner/project/.../experiment_name)
+        if exp_info.prefix:
+            full_prefix = exp_info.prefix
+        else:
+            # Build prefix from owner/project/experiment
+            if exp_info.owner:
+                full_prefix = f"{exp_info.owner}/{exp_info.project}/{exp_info.experiment}"
+            else:
+                full_prefix = f"{exp_info.project}/{exp_info.experiment}"
+
+        # Extract owner from prefix or use from exp_info
+        owner = exp_info.owner if exp_info.owner else full_prefix.split('/')[0]
+
+        self.local.create_experiment(
+            owner=owner,
+            project=exp_info.project,
+            prefix=full_prefix,
+            description=exp_info.description,
+            tags=exp_info.tags,
+            bindrs=[],
+            metadata=None,
+        )

- [7 deleted lines not rendered]
-            tags=exp_info.tags,
-            bindrs=[],
-            folder=exp_info.folder,
-            metadata=None,
+    def _download_parameters(self, exp_info: ExperimentInfo, result: DownloadResult):
+        """Download parameters."""
+        try:
+            params_data = self.remote.get_parameters(exp_info.experiment_id)
+            if params_data:
+                self.local.write_parameters(
+                    project=exp_info.project, experiment=exp_info.experiment, data=params_data
                 )
+                result.downloaded["parameters"] = 1
+                result.bytes_downloaded += len(json.dumps(params_data))
+        except Exception as e:
+            result.failed.setdefault("parameters", []).append(str(e))

- [15 deleted lines not rendered]
-    def _download_logs(self, exp_info: ExperimentInfo, result: DownloadResult):
-        """Download logs with pagination."""
-        try:
-            offset = 0
-            total_downloaded = 0
-
-            while True:
-                logs_data = self.remote.query_logs(
-                    experiment_id=exp_info.experiment_id,
-                    limit=self.batch_size,
-                    offset=offset,
-                    order_by="sequenceNumber",
-                    order="asc"
-                )
-
-                logs = logs_data.get("logs", [])
-                if not logs:
-                    break
-
-                # Write logs
-                for log in logs:
-                    self.local.write_log(
-                        project=exp_info.project,
-                        experiment=exp_info.experiment,
-                        message=log['message'],
-                        level=log['level'],
-                        timestamp=log['timestamp'],
-                        metadata=log.get('metadata')
-                    )
-
-                total_downloaded += len(logs)
-                result.bytes_downloaded += sum(len(json.dumps(log)) for log in logs)
-
-                if not logs_data.get("hasMore", False):
-                    break
-
-                offset += len(logs)
-
-            result.downloaded["logs"] = total_downloaded
-
-        except Exception as e:
-            result.failed.setdefault("logs", []).append(str(e))
-
-    def _download_metrics(self, exp_info: ExperimentInfo, result: DownloadResult):
-        """Download all metrics in parallel."""
-        with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
-            future_to_metric = {}
-
-            for metric_name in exp_info.metric_names:
-                future = executor.submit(
-                    self._download_single_metric,
-                    exp_info.experiment_id,
-                    exp_info.project,
-                    exp_info.experiment,
-                    metric_name
-                )
-                future_to_metric[future] = metric_name
-
-            for future in as_completed(future_to_metric):
-                metric_name = future_to_metric[future]
-                metric_result = future.result()
-
-                with self._lock:
-                    if metric_result['success']:
-                        result.downloaded["metrics"] = result.downloaded.get("metrics", 0) + 1
-                        result.bytes_downloaded += metric_result['bytes']
-                    else:
-                        result.failed.setdefault("metrics", []).append(
-                            f"{metric_name}: {metric_result['error']}"
-                        )
-
-    def _download_single_chunk(self, experiment_id: str, metric_name: str, chunk_number: int):
-        """Download a single chunk (for parallel downloading)."""
-        remote = self._get_remote_client()
-        try:
-            chunk_data = remote.download_metric_chunk(experiment_id, metric_name, chunk_number)
-            return {
-                'success': True,
-                'chunk_number': chunk_number,
-                'data': chunk_data.get('data', []),
-                'start_index': int(chunk_data.get('startIndex', 0)),
-                'end_index': int(chunk_data.get('endIndex', 0)),
-                'error': None
-            }
-        except Exception as e:
-            return {
-                'success': False,
-                'chunk_number': chunk_number,
-                'error': str(e),
-                'data': []
-            }
-
-    def _download_single_metric(self, experiment_id: str, project: str, experiment: str, metric_name: str):
-        """Download a single metric using chunk-aware approach (thread-safe)."""
-        remote = self._get_remote_client()
-
-        total_downloaded = 0
-        bytes_downloaded = 0
-
-        try:
-            # Get metric metadata to determine download strategy
-            metadata = remote.get_metric_stats(experiment_id, metric_name)
-            total_chunks = metadata.get('totalChunks', 0)
-            buffered_points = int(metadata.get('bufferedDataPoints', 0))
-
-            all_data = []
-
-            # Download chunks in parallel if they exist
-            if total_chunks > 0:
-                from concurrent.futures import ThreadPoolExecutor, as_completed
-
-                # Download all chunks in parallel (max 10 workers)
-                with ThreadPoolExecutor(max_workers=min(10, total_chunks)) as executor:
-                    chunk_futures = {
-                        executor.submit(self._download_single_chunk, experiment_id, metric_name, i): i
-                        for i in range(total_chunks)
-                    }
-
-                    for future in as_completed(chunk_futures):
-                        result = future.result()
-                        if result['success']:
-                            all_data.extend(result['data'])
-                        else:
-                            # If a chunk fails, fall back to pagination
-                            raise Exception(f"Chunk {result['chunk_number']} download failed: {result['error']}")
-
-            # Download buffer data if exists
-            if buffered_points > 0:
-                response = remote.get_metric_data(
-                    experiment_id, metric_name,
-                    buffer_only=True
-                )
-                buffer_data = response.get('data', [])
-                all_data.extend(buffer_data)
-
-            # Sort all data by index
-            all_data.sort(key=lambda x: int(x.get('index', 0)))
-
-            # Write to local storage in batches
-            batch_size = 10000
-            for i in range(0, len(all_data), batch_size):
-                batch = all_data[i:i + batch_size]
-                self.local.append_batch_to_metric(
-                    project, experiment, metric_name,
-                    data_points=[d['data'] for d in batch]
-                )
-                total_downloaded += len(batch)
-                bytes_downloaded += sum(len(json.dumps(d)) for d in batch)
-
-            return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
-
-        except Exception as e:
-            # Fall back to pagination if chunk download fails
-            console.print(f"[yellow]Chunk download failed for {metric_name}, falling back to pagination: {e}[/yellow]")
-            return self._download_metric_with_pagination(experiment_id, project, experiment, metric_name)
-
-    def _download_metric_with_pagination(self, experiment_id: str, project: str, experiment: str, metric_name: str):
-        """Original pagination-based download (fallback method)."""
-        remote = self._get_remote_client()
-
-        total_downloaded = 0
-        bytes_downloaded = 0
-        start_index = 0
-
-        try:
-            while True:
-                response = remote.get_metric_data(
-                    experiment_id, metric_name,
-                    start_index=start_index,
-                    limit=self.batch_size
-                )
-
-                data_points = response.get('data', [])
-                if not data_points:
-                    break
-
-                # Write to local storage
-                self.local.append_batch_to_metric(
-                    project, experiment, metric_name,
-                    data_points=[d['data'] for d in data_points]
-                )
-
-                total_downloaded += len(data_points)
-                bytes_downloaded += sum(len(json.dumps(d)) for d in data_points)
-
-                if not response.get('hasMore', False):
-                    break
-
-                start_index += len(data_points)
-
-            return {'success': True, 'downloaded': total_downloaded, 'bytes': bytes_downloaded, 'error': None}
-
-        except Exception as e:
-            return {'success': False, 'error': str(e), 'downloaded': 0, 'bytes': 0}
-
-    def _download_files(self, exp_info: ExperimentInfo, result: DownloadResult):
-        """Download files in parallel."""
-        # Get file list
-        try:
-            files_data = self.remote.list_files(exp_info.experiment_id)
-        except Exception as e:
-            result.failed.setdefault("files", []).append(f"List files failed: {e}")
-            return
-
-        if not files_data:
-            return
-
-        with ThreadPoolExecutor(max_workers=self.max_concurrent_files) as executor:
-            future_to_file = {}
-
-            for file_info in files_data:
-                future = executor.submit(
-                    self._download_single_file,
-                    exp_info.experiment_id,
-                    exp_info.project,
-                    exp_info.experiment,
-                    file_info
-                )
-                future_to_file[future] = file_info['filename']
-
-            for future in as_completed(future_to_file):
-                filename = future_to_file[future]
-                file_result = future.result()
-
-                with self._lock:
-                    if file_result['success']:
-                        result.downloaded["files"] = result.downloaded.get("files", 0) + 1
-                        result.bytes_downloaded += file_result['bytes']
-                    else:
-                        result.failed.setdefault("files", []).append(
-                            f"{filename}: {file_result['error']}"
-                        )
-
-    def _download_single_file(self, experiment_id: str, project: str, experiment: str, file_info: Dict[str, Any]):
-        """Download a single file with streaming (thread-safe)."""
-        remote = self._get_remote_client()
-
-        try:
-            # Stream download to temp file
-            temp_fd, temp_path = tempfile.mkstemp(prefix="ml_dash_download_")
-            os.close(temp_fd)
-
-            remote.download_file_streaming(
-                experiment_id, file_info['id'], dest_path=temp_path
-            )
+    def _download_logs(self, exp_info: ExperimentInfo, result: DownloadResult):
+        """Download logs with pagination."""
+        try:
+            offset = 0
+            total_downloaded = 0
+
+            while True:
+                logs_data = self.remote.query_logs(
+                    experiment_id=exp_info.experiment_id,
+                    limit=self.batch_size,
+                    offset=offset,
+                    order_by="sequenceNumber",
+                    order="asc",
+                )

- [3 deleted lines not rendered]
-                experiment=experiment,
-                file_path=temp_path,
-                prefix=file_info['path'],
-                filename=file_info['filename'],
-                description=file_info.get('description'),
-                tags=file_info.get('tags', []),
-                metadata=file_info.get('metadata'),
-                checksum=file_info['checksum'],
-                content_type=file_info['contentType'],
-                size_bytes=file_info['sizeBytes']
-            )
+                logs = logs_data.get("logs", [])
+                if not logs:
+                    break

- [2 deleted lines not rendered]
+                # Write logs
+                for log in logs:
+                    self.local.write_log(
+                        project=exp_info.project,
+                        experiment=exp_info.experiment,
+                        message=log["message"],
+                        level=log["level"],
+                        timestamp=log["timestamp"],
+                        metadata=log.get("metadata"),
+                    )

- [1 deleted line not rendered]
+                total_downloaded += len(logs)
+                result.bytes_downloaded += sum(len(json.dumps(log)) for log in logs)

- [2 deleted lines not rendered]
+                if not logs_data.get("hasMore", False):
+                    break

+                offset += len(logs)

- [1 deleted line not rendered]
-# Main Command
-# ============================================================================
+            result.downloaded["logs"] = total_downloaded

- [16 deleted lines not rendered]
-    # Load or create state
-    state_file = Path(args.state_file)
-    if args.resume:
-        state = DownloadState.load(state_file)
-        if state:
-            console.print(f"[cyan]Resuming from previous download ({len(state.completed_experiments)} completed)[/cyan]")
-        else:
-            console.print("[yellow]No previous state found, starting fresh[/yellow]")
-            state = DownloadState(
-                remote_url=remote_url,
-                local_path=str(args.path)
-            )
-    else:
-        state = DownloadState(
-            remote_url=remote_url,
-            local_path=str(args.path)
+        except Exception as e:
+            result.failed.setdefault("logs", []).append(str(e))
+
+    def _download_metrics(self, exp_info: ExperimentInfo, result: DownloadResult):
+        """Download all metrics in parallel."""
+        with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
+            future_to_metric = {}
+
+            for metric_name in exp_info.metric_names:
+                future = executor.submit(
+                    self._download_single_metric,
+                    exp_info.experiment_id,
+                    exp_info.project,
+                    exp_info.experiment,
+                    metric_name,
                 )
+                future_to_metric[future] = metric_name
+
+            for future in as_completed(future_to_metric):
+                metric_name = future_to_metric[future]
+                metric_result = future.result()
+
+                with self._lock:
+                    if metric_result["success"]:
+                        result.downloaded["metrics"] = result.downloaded.get("metrics", 0) + 1
+                        result.bytes_downloaded += metric_result["bytes"]
+                    else:
+                        result.failed.setdefault("metrics", []).append(
+                            f"{metric_name}: {metric_result['error']}"
+                        )

- [2 deleted lines not rendered]
+    def _download_single_chunk(
+        self, experiment_id: str, metric_name: str, chunk_number: int
+    ):
+        """Download a single chunk (for parallel downloading)."""
+        remote = self._get_remote_client()
         try:
- [2 deleted lines not rendered]
+            chunk_data = remote.download_metric_chunk(
+                experiment_id, metric_name, chunk_number
+            )
+            return {
+                "success": True,
+                "chunk_number": chunk_number,
+                "data": chunk_data.get("data", []),
+                "start_index": int(chunk_data.get("startIndex", 0)),
+                "end_index": int(chunk_data.get("endIndex", 0)),
+                "error": None,
+            }
+        except Exception as e:
+            return {
+                "success": False,
+                "chunk_number": chunk_number,
+                "error": str(e),
+                "data": [],
+            }
+
+    def _download_single_metric(
+        self, experiment_id: str, project: str, experiment: str, metric_name: str
+    ):
+        """Download a single metric using chunk-aware approach (thread-safe)."""
+        remote = self._get_remote_client()
+
+        total_downloaded = 0
+        bytes_downloaded = 0
+
+        try:
+            # Get metric metadata to determine download strategy
+            metadata = remote.get_metric_stats(experiment_id, metric_name)
+            total_chunks = metadata.get("totalChunks", 0)
+            buffered_points = int(metadata.get("bufferedDataPoints", 0))
+
+            all_data = []
+
+            # Download chunks in parallel if they exist
+            if total_chunks > 0:
+                from concurrent.futures import ThreadPoolExecutor, as_completed
+
+                # Download all chunks in parallel (max 10 workers)
+                with ThreadPoolExecutor(max_workers=min(10, total_chunks)) as executor:
+                    chunk_futures = {
+                        executor.submit(
+                            self._download_single_chunk, experiment_id, metric_name, i
+                        ): i
+                        for i in range(total_chunks)
+                    }
+
+                    for future in as_completed(chunk_futures):
+                        result = future.result()
+                        if result["success"]:
+                            all_data.extend(result["data"])
+                        else:
+                            # If a chunk fails, fall back to pagination
+                            raise Exception(
+                                f"Chunk {result['chunk_number']} download failed: {result['error']}"
+                            )
+
+            # Download buffer data if exists
+            if buffered_points > 0:
+                response = remote.get_metric_data(experiment_id, metric_name, buffer_only=True)
+                buffer_data = response.get("data", [])
+                all_data.extend(buffer_data)
+
+            # Sort all data by index
+            all_data.sort(key=lambda x: int(x.get("index", 0)))
+
+            # Write to local storage in batches
+            batch_size = 10000
+            for i in range(0, len(all_data), batch_size):
+                batch = all_data[i : i + batch_size]
+                self.local.append_batch_to_metric(
+                    project, experiment, metric_name, data_points=[d["data"] for d in batch]
                 )
+                total_downloaded += len(batch)
+                bytes_downloaded += sum(len(json.dumps(d)) for d in batch)
+
+            return {
+                "success": True,
+                "downloaded": total_downloaded,
+                "bytes": bytes_downloaded,
+                "error": None,
+            }
+
         except Exception as e:
- [22 deleted lines not rendered]
-            console.print(f"  [yellow]Skipping {exp_key} (already exists locally)[/yellow]")
-            continue
-
-        experiments_to_download.append(exp)
-
-    if not experiments_to_download:
-        console.print("[green]All experiments already downloaded[/green]")
-        return 0
-
-    # Dry run mode
-    if args.dry_run:
-        console.print("\n[bold]Dry run - would download:[/bold]")
-        for exp in experiments_to_download:
-            console.print(f"  • {exp.project}/{exp.experiment}")
-            console.print(f"    Logs: {exp.log_count}, Metrics: {len(exp.metric_names)}, Files: {exp.file_count}")
-        return 0
-
-    # Download experiments
-    console.print(f"\n[bold]Downloading {len(experiments_to_download)} experiment(s)...[/bold]")
-    results = []
-    start_time = time.time()
-
-    for i, exp in enumerate(experiments_to_download, 1):
-        exp_key = f"{exp.project}/{exp.experiment}"
-        console.print(f"\n[cyan][{i}/{len(experiments_to_download)}] {exp_key}[/cyan]")
-
-        # Mark as in-progress
-        state.in_progress_experiment = exp_key
-        state.save(state_file)
-
-        # Download
-        downloader = ExperimentDownloader(
-            local_storage=local_storage,
-            remote_client=remote_client,
-            batch_size=args.batch_size,
-            skip_logs=args.skip_logs,
-            skip_metrics=args.skip_metrics,
-            skip_files=args.skip_files,
-            skip_params=args.skip_params,
-            verbose=args.verbose,
-            max_concurrent_metrics=args.max_concurrent_metrics,
-            max_concurrent_files=args.max_concurrent_files,
+            # Fall back to pagination if chunk download fails
+            console.print(
+                f"[yellow]Chunk download failed for {metric_name}, falling back to pagination: {e}[/yellow]"
+            )
+            return self._download_metric_with_pagination(
+                experiment_id, project, experiment, metric_name
+            )
+
+    def _download_metric_with_pagination(
+        self, experiment_id: str, project: str, experiment: str, metric_name: str
+    ):
+        """Original pagination-based download (fallback method)."""
+        remote = self._get_remote_client()
+
+        total_downloaded = 0
+        bytes_downloaded = 0
+        start_index = 0
+
+        try:
+            while True:
+                response = remote.get_metric_data(
+                    experiment_id, metric_name, start_index=start_index, limit=self.batch_size
                 )

- [2 deleted lines not rendered]
+                data_points = response.get("data", [])
+                if not data_points:
+                    break

- [4 deleted lines not rendered]
-        else:
-            state.failed_experiments.append(exp_key)
-            console.print(f"  [red]✗ Failed: {', '.join(result.errors)}[/red]")
+                # Write to local storage
+                self.local.append_batch_to_metric(
+                    project, experiment, metric_name, data_points=[d["data"] for d in data_points]
+                )

- [2 deleted lines not rendered]
+                total_downloaded += len(data_points)
+                bytes_downloaded += sum(len(json.dumps(d)) for d in data_points)

- [2 deleted lines not rendered]
-    elapsed_time = end_time - start_time
-    total_bytes = sum(r.bytes_downloaded for r in results)
-    successful = sum(1 for r in results if r.success)
+                if not response.get("hasMore", False):
+                    break

- [1 deleted line not rendered]
-    summary_table = Table()
-    summary_table.add_column("Metric", style="cyan")
-    summary_table.add_column("Value", style="green")
+                start_index += len(data_points)

- [5 deleted lines not rendered]
+            return {
+                "success": True,
+                "downloaded": total_downloaded,
+                "bytes": bytes_downloaded,
+                "error": None,
+            }

- [2 deleted lines not rendered]
-        summary_table.add_row("Avg Speed", _format_bytes_per_sec(speed))
+        except Exception as e:
+            return {"success": False, "error": str(e), "downloaded": 0, "bytes": 0}

- [1 deleted line not rendered]
+    def _download_files(self, exp_info: ExperimentInfo, result: DownloadResult):
+        """Download files in parallel."""
+        # Get file list
+        try:
+            files_data = self.remote.list_files(exp_info.experiment_id)
+        except Exception as e:
+            result.failed.setdefault("files", []).append(f"List files failed: {e}")
+            return
+
+        if not files_data:
+            return
+
+        with ThreadPoolExecutor(max_workers=self.max_concurrent_files) as executor:
+            future_to_file = {}
+
+            for file_info in files_data:
+                future = executor.submit(
+                    self._download_single_file,
+                    exp_info.experiment_id,
+                    exp_info.project,
+                    exp_info.experiment,
+                    file_info,
+                )
+                future_to_file[future] = file_info["filename"]
+
+            for future in as_completed(future_to_file):
+                filename = future_to_file[future]
+                file_result = future.result()
+
+                with self._lock:
+                    if file_result["success"]:
+                        result.downloaded["files"] = result.downloaded.get("files", 0) + 1
+                        result.bytes_downloaded += file_result["bytes"]
+                    else:
+                        result.failed.setdefault("files", []).append(
+                            f"{filename}: {file_result['error']}"
+                        )

- [3 deleted lines not rendered]
+    def _download_single_file(
+        self, experiment_id: str, project: str, experiment: str, file_info: Dict[str, Any]
+    ):
+        """Download a single file with streaming (thread-safe)."""
+        remote = self._get_remote_client()

- [1 deleted line not rendered]
+        try:
+            # Stream download to temp file
+            temp_fd, temp_path = tempfile.mkstemp(prefix="ml_dash_download_")
+            os.close(temp_fd)
+
+            remote.download_file_streaming(
+                experiment_id, file_info["id"], dest_path=temp_path
+            )
+
+            # Write to local storage
+            self.local.write_file(
+                project=project,
+                experiment=experiment,
+                file_path=temp_path,
+                prefix=file_info["path"],
+                filename=file_info["filename"],
+                description=file_info.get("description"),
+                tags=file_info.get("tags", []),
+                metadata=file_info.get("metadata"),
+                checksum=file_info["checksum"],
+                content_type=file_info["contentType"],
+                size_bytes=file_info["sizeBytes"],
+            )
+
+            # Clean up temp file
+            os.remove(temp_path)
+
+            return {"success": True, "bytes": file_info["sizeBytes"], "error": None}

+        except Exception as e:
+            return {"success": False, "error": str(e), "bytes": 0}

-def add_parser(subparsers):
-    """Add download command parser."""
-    parser = subparsers.add_parser(
-        "download",
-        help="Download experiments from remote server to local storage"
-    )

- [3 deleted lines not rendered]
-        nargs="?",
-        default="./.ml-dash",
-        help="Local storage directory (default: ./.ml-dash)"
-    )
+# ============================================================================
+# Main Command
+# ============================================================================

- [34 deleted lines not rendered]
+
+def cmd_download(args: argparse.Namespace) -> int:
+    """Execute download command."""
+    # Load configuration
+    config = Config()
+    remote_url = args.dash_url or config.remote_url
+    api_key = args.api_key or config.api_key  # RemoteClient will auto-load if None
+
+    # Validate inputs
+    if not remote_url:
+        console.print("[red]Error:[/red] --dash-url is required (or set in config)")
+        return 1
+
+    # Initialize clients (RemoteClient will auto-load token if api_key is None)
+    remote_client = RemoteClient(base_url=remote_url, api_key=api_key)
+    local_storage = LocalStorage(root_path=Path(args.path))
+
+    # Load or create state
+    state_file = Path(args.state_file)
+    if args.resume:
+        state = DownloadState.load(state_file)
+        if state:
+            console.print(
+                f"[cyan]Resuming from previous download ({len(state.completed_experiments)} completed)[/cyan]"
+            )
+        else:
+            console.print("[yellow]No previous state found, starting fresh[/yellow]")
+            state = DownloadState(remote_url=remote_url, local_path=str(args.path))
+    else:
+        state = DownloadState(remote_url=remote_url, local_path=str(args.path))
+
+    # Discover experiments
+    console.print("[bold]Discovering experiments on remote server...[/bold]")
+    try:
+        experiments = discover_experiments(remote_client, args.project, args.experiment)
+    except Exception as e:
+        console.print(f"[red]Failed to discover experiments: {e}[/red]")
+        return 1
+
+    if not experiments:
+        console.print("[yellow]No experiments found[/yellow]")
+        return 0
+
+    console.print(f"Found {len(experiments)} experiment(s)")
+
+    # Filter out completed experiments
+    experiments_to_download = []
+    for exp in experiments:
+        exp_key = f"{exp.project}/{exp.experiment}"
+
+        # Skip if already completed
+        if exp_key in state.completed_experiments and not args.overwrite:
+            console.print(f"  [dim]Skipping {exp_key} (already completed)[/dim]")
+            continue
+
+        # Check if exists locally
+        exp_json = (
+            local_storage.root_path / exp.project / exp.experiment / "experiment.json"
         )
- [5 deleted lines not rendered]
+        if exp_json.exists() and not args.overwrite:
+            console.print(f"  [yellow]Skipping {exp_key} (already exists locally)[/yellow]")
+            continue
+
+        experiments_to_download.append(exp)
+
+    if not experiments_to_download:
+        console.print("[green]All experiments already downloaded[/green]")
+        return 0
+
+    # Dry run mode
+    if args.dry_run:
+        console.print("\n[bold]Dry run - would download:[/bold]")
+        for exp in experiments_to_download:
+            console.print(f"  • {exp.project}/{exp.experiment}")
+            console.print(
+                f"    Logs: {exp.log_count}, Metrics: {len(exp.metric_names)}, Files: {exp.file_count}"
+            )
+        return 0
+
+    # Download experiments
+    console.print(
+        f"\n[bold]Downloading {len(experiments_to_download)} experiment(s)...[/bold]"
+    )
+    results = []
+    start_time = time.time()
+
+    for i, exp in enumerate(experiments_to_download, 1):
+        exp_key = f"{exp.project}/{exp.experiment}"
+        console.print(f"\n[cyan][{i}/{len(experiments_to_download)}] {exp_key}[/cyan]")
+
+        # Mark as in-progress
+        state.in_progress_experiment = exp_key
+        state.save(state_file)
+
+        # Download
+        downloader = ExperimentDownloader(
+            local_storage=local_storage,
+            remote_client=remote_client,
+            batch_size=args.batch_size,
+            skip_logs=args.skip_logs,
+            skip_metrics=args.skip_metrics,
+            skip_files=args.skip_files,
+            skip_params=args.skip_params,
+            verbose=args.verbose,
+            max_concurrent_metrics=args.max_concurrent_metrics,
+            max_concurrent_files=args.max_concurrent_files,
         )
-    parser.add_argument("-v", "--verbose", action="store_true", help="Detailed progress output")

- [1 deleted line not rendered]
+        result = downloader.download_experiment(exp)
+        results.append(result)
+
+        # Update state
+        if result.success:
+            state.completed_experiments.append(exp_key)
+            console.print("  [green]✓ Downloaded successfully[/green]")
+        else:
+            state.failed_experiments.append(exp_key)
+            console.print(f"  [red]✗ Failed: {', '.join(result.errors)}[/red]")
+
+        state.in_progress_experiment = None
+        state.save(state_file)
+
+    # Show summary
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    total_bytes = sum(r.bytes_downloaded for r in results)
+    successful = sum(1 for r in results if r.success)
+
+    console.print("\n[bold]Download Summary[/bold]")
+    summary_table = Table()
+    summary_table.add_column("Metric", style="cyan")
+    summary_table.add_column("Value", style="green")
+
+    summary_table.add_row("Total Experiments", str(len(results)))
+    summary_table.add_row("Successful", str(successful))
+    summary_table.add_row("Failed", str(len(results) - successful))
+    summary_table.add_row("Total Data", _format_bytes(total_bytes))
+    summary_table.add_row("Total Time", f"{elapsed_time:.2f}s")
+
+    if elapsed_time > 0:
+        speed = total_bytes / elapsed_time
+        summary_table.add_row("Avg Speed", _format_bytes_per_sec(speed))
+
+    console.print(summary_table)
+
+    # Clean up state file if all successful
+    if all(r.success for r in results):
+        state_file.unlink(missing_ok=True)
+
+    return 0 if all(r.success for r in results) else 1
+
+
+def add_parser(subparsers):
+    """Add download command parser."""
+    parser = subparsers.add_parser(
+        "download", help="Download experiments from remote server to local storage"
+    )
+
+    # Positional arguments
+    parser.add_argument(
+        "path",
+        nargs="?",
+        default="./.dash",
+        help="Local storage directory (default: ./.dash)",
+    )
+
+    # Remote configuration
+    parser.add_argument(
+        "--dash-url",
+        help="ML-Dash server URL (defaults to config or https://api.dash.ml)",
+    )
+    parser.add_argument(
+        "--api-key",
+        help="JWT authentication token (optional - auto-loads from 'ml-dash login')",
+    )
+
+    # Scope control
+    parser.add_argument(
+        "-p",
+        "--pref",
+        "--prefix",
+        "--proj",
+        "--project",
+        dest="project",
+        type=str,
+        help="Filter experiments by project or pattern (supports glob: 'tut*', 'tom*/tutorials/*')"
+    )
+    parser.add_argument(
+        "--experiment",
+        help="Download specific experiment (requires --project)"
+    )
+
+    # Data filtering
+    parser.add_argument("--skip-logs", action="store_true", help="Don't download logs")
+    parser.add_argument(
+        "--skip-metrics", action="store_true", help="Don't download metrics"
+    )
+    parser.add_argument("--skip-files", action="store_true", help="Don't download files")
+    parser.add_argument(
+        "--skip-params", action="store_true", help="Don't download parameters"
+    )
+
+    # Behavior control
+    parser.add_argument(
+        "--dry-run", action="store_true", help="Preview without downloading"
+    )
+    parser.add_argument(
+        "--overwrite", action="store_true", help="Overwrite existing experiments"
+    )
+    parser.add_argument(
+        "--resume", action="store_true", help="Resume interrupted download"
+    )
+    parser.add_argument(
+        "--state-file",
+        default=".dash-download-state.json",
+        help="State file path for resume (default: .dash-download-state.json)",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1000,
+        help="Batch size for logs/metrics (default: 1000, max: 10000)",
+    )
+    parser.add_argument(
+        "--max-concurrent-metrics",
+        type=int,
+        default=5,
+        help="Parallel metric downloads (default: 5)",
+    )
+    parser.add_argument(
+        "--max-concurrent-files",
+        type=int,
+        default=3,
+        help="Parallel file downloads (default: 3)",
+    )
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Detailed progress output"
+    )
+
+    parser.set_defaults(func=cmd_download)
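
The --resume flag above is driven entirely by the new DownloadState dataclass: cmd_download saves a snapshot to the state file before and after each experiment, and a resumed run skips anything already listed in completed_experiments. A minimal sketch of that round trip, assuming ml-dash 0.6.3 is installed (the import path follows the changed file, the state-file name and paths mirror the argparse defaults above, and the experiment key is hypothetical):

# Sketch of the checkpoint/resume round trip behind "ml-dash download --resume".
from pathlib import Path

from ml_dash.cli_commands.download import DownloadState

state_file = Path(".dash-download-state.json")

# Before an experiment starts, the downloader records it as in-progress.
state = DownloadState(remote_url="https://api.dash.ml", local_path="./.dash")
state.in_progress_experiment = "my-project/run-1"  # hypothetical experiment key
state.save(state_file)

# After it finishes, the key moves to completed_experiments.
state.completed_experiments.append("my-project/run-1")
state.in_progress_experiment = None
state.save(state_file)

# A later --resume run reloads the file and skips completed experiments.
resumed = DownloadState.load(state_file)
assert resumed is not None
assert "my-project/run-1" in resumed.completed_experiments

Note that cmd_download removes the state file once every experiment succeeds (state_file.unlink(missing_ok=True)), so the checkpoint only survives a partial or failed run.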