flyteplugins-wandb 2.0.0b52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/wandb/__init__.py +493 -0
- flyteplugins/wandb/_context.py +381 -0
- flyteplugins/wandb/_decorator.py +417 -0
- flyteplugins/wandb/_link.py +149 -0
- flyteplugins_wandb-2.0.0b52.dist-info/METADATA +34 -0
- flyteplugins_wandb-2.0.0b52.dist-info/RECORD +8 -0
- flyteplugins_wandb-2.0.0b52.dist-info/WHEEL +5 -0
- flyteplugins_wandb-2.0.0b52.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,493 @@
|
|
|
1
|
+
"""
|
|
2
|
+
## Key features:
|
|
3
|
+
|
|
4
|
+
- Automatic W&B run initialization with `@wandb_init` decorator
|
|
5
|
+
- Automatic W&B links in Flyte UI pointing to runs and sweeps
|
|
6
|
+
- Parent/child task support with automatic run reuse
|
|
7
|
+
- W&B sweep creation and management with `@wandb_sweep` decorator
|
|
8
|
+
- Configuration management with `wandb_config()` and `wandb_sweep_config()`
|
|
9
|
+
|
|
10
|
+
## Basic usage:
|
|
11
|
+
|
|
12
|
+
1. Simple task with W&B logging:
|
|
13
|
+
|
|
14
|
+
```python
|
|
15
|
+
from flyteplugins.wandb import wandb_init, get_wandb_run
|
|
16
|
+
|
|
17
|
+
@wandb_init(project="my-project", entity="my-team")
|
|
18
|
+
@env.task
|
|
19
|
+
async def train_model(learning_rate: float) -> str:
|
|
20
|
+
wandb_run = get_wandb_run()
|
|
21
|
+
wandb_run.log({"loss": 0.5, "learning_rate": learning_rate})
|
|
22
|
+
return wandb_run.id
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
2. Parent/Child Tasks with Run Reuse:
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
@wandb_init # Automatically reuses parent's run ID
|
|
29
|
+
@env.task
|
|
30
|
+
async def child_task(x: int) -> str:
|
|
31
|
+
wandb_run = get_wandb_run()
|
|
32
|
+
wandb_run.log({"child_metric": x * 2})
|
|
33
|
+
return wandb_run.id
|
|
34
|
+
|
|
35
|
+
@wandb_init(project="my-project", entity="my-team")
|
|
36
|
+
@env.task
|
|
37
|
+
async def parent_task() -> str:
|
|
38
|
+
wandb_run = get_wandb_run()
|
|
39
|
+
wandb_run.log({"parent_metric": 100})
|
|
40
|
+
|
|
41
|
+
# Child reuses parent's run by default (run_mode="auto")
|
|
42
|
+
await child_task(5)
|
|
43
|
+
return wandb_run.id
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
3. Configuration with context manager:
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
from flyteplugins.wandb import wandb_config
|
|
50
|
+
|
|
51
|
+
r = flyte.with_runcontext(
|
|
52
|
+
custom_context=wandb_config(
|
|
53
|
+
project="my-project",
|
|
54
|
+
entity="my-team",
|
|
55
|
+
tags=["experiment-1"]
|
|
56
|
+
)
|
|
57
|
+
).run(train_model, learning_rate=0.001)
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
4. Creating new runs for child tasks:
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
@wandb_init(run_mode="new") # Always creates a new run
|
|
64
|
+
@env.task
|
|
65
|
+
async def independent_child() -> str:
|
|
66
|
+
wandb_run = get_wandb_run()
|
|
67
|
+
wandb_run.log({"independent_metric": 42})
|
|
68
|
+
return wandb_run.id
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
5. Running sweep agents in parallel:
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
import asyncio
|
|
75
|
+
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_context,
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)
|
|
76
|
+
|
|
77
|
+
@wandb_init
|
|
78
|
+
async def objective():
|
|
79
|
+
wandb_run = wandb.run
|
|
80
|
+
config = wandb_run.config
|
|
81
|
+
...
|
|
82
|
+
|
|
83
|
+
wandb_run.log({"loss": loss_value})
|
|
84
|
+
|
|
85
|
+
@wandb_sweep
|
|
86
|
+
@env.task
|
|
87
|
+
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
|
|
88
|
+
wandb.agent(sweep_id, function=objective, count=count, project=get_wandb_context().project)
|
|
89
|
+
return agent_id
|
|
90
|
+
|
|
91
|
+
@wandb_sweep
|
|
92
|
+
@env.task
|
|
93
|
+
async def run_parallel_sweep(num_agents: int = 2, trials_per_agent: int = 5) -> str:
|
|
94
|
+
sweep_id = get_wandb_sweep_id()
|
|
95
|
+
|
|
96
|
+
# Launch agents in parallel
|
|
97
|
+
agent_tasks = [
|
|
98
|
+
sweep_agent(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
|
|
99
|
+
for i in range(num_agents)
|
|
100
|
+
]
|
|
101
|
+
|
|
102
|
+
# Wait for all agents to complete
|
|
103
|
+
await asyncio.gather(*agent_tasks)
|
|
104
|
+
return sweep_id
|
|
105
|
+
|
|
106
|
+
# Run with 2 parallel agents
|
|
107
|
+
r = flyte.with_runcontext(
|
|
108
|
+
custom_context={
|
|
109
|
+
**wandb_config(project="my-project", entity="my-team"),
|
|
110
|
+
**wandb_sweep_config(
|
|
111
|
+
method="random",
|
|
112
|
+
metric={"name": "loss", "goal": "minimize"},
|
|
113
|
+
parameters={
|
|
114
|
+
"learning_rate": {"min": 0.0001, "max": 0.1},
|
|
115
|
+
"batch_size": {"values": [16, 32, 64]},
|
|
116
|
+
}
|
|
117
|
+
)
|
|
118
|
+
}
|
|
119
|
+
).run(run_parallel_sweep, num_agents=2, trials_per_agent=5)
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
Decorator order: `@wandb_init` or `@wandb_sweep` must be the outermost decorator:
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
@wandb_init
|
|
126
|
+
@env.task
|
|
127
|
+
async def my_task():
|
|
128
|
+
...
|
|
129
|
+
```
|
|
130
|
+
"""
|
|
131
|
+
|
|
132
|
+
import json
|
|
133
|
+
import logging
|
|
134
|
+
import os
|
|
135
|
+
from typing import Optional
|
|
136
|
+
|
|
137
|
+
import flyte
|
|
138
|
+
from flyte.io import Dir
|
|
139
|
+
|
|
140
|
+
import wandb
|
|
141
|
+
|
|
142
|
+
from ._context import (
|
|
143
|
+
get_wandb_context,
|
|
144
|
+
get_wandb_sweep_context,
|
|
145
|
+
wandb_config,
|
|
146
|
+
wandb_sweep_config,
|
|
147
|
+
)
|
|
148
|
+
from ._decorator import wandb_init, wandb_sweep
|
|
149
|
+
from ._link import Wandb, WandbSweep
|
|
150
|
+
|
|
151
|
+
logger = logging.getLogger(__name__)

# Public API of the plugin. The order of entries is intentional and unchanged:
# link classes first, then the remaining helpers alphabetically.
__all__ = [
    # Flyte UI link types
    "Wandb", "WandbSweep",
    # Download helpers (local paths and traced Dir outputs)
    "download_wandb_run_dir", "download_wandb_run_logs",
    "download_wandb_sweep_dirs", "download_wandb_sweep_logs",
    # Context accessors
    "get_wandb_context", "get_wandb_run", "get_wandb_run_dir",
    "get_wandb_sweep_context", "get_wandb_sweep_id",
    # Configuration helpers and decorators
    "wandb_config", "wandb_init", "wandb_sweep", "wandb_sweep_config",
]

# NOTE(review): this does not match the distribution version (2.0.0b52) shown
# in the package metadata — confirm which value is intended.
__version__ = "0.1.0"
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def get_wandb_run():
    """
    Return the wandb run attached to the current Flyte context, if any.

    The run is initialized when the `@wandb_init` context manager is entered;
    this helper only reads it back from the Flyte context data under the
    `_wandb_run` key.

    Returns:
        `wandb.sdk.wandb_run.Run` | `None`: The current wandb run object, or
        `None` when there is no Flyte context, the context carries no data,
        or no run has been stored.
    """
    current_ctx = flyte.ctx()
    ctx_data = current_ctx.data if current_ctx else None
    if not ctx_data:
        return None
    return ctx_data.get("_wandb_run")
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_wandb_sweep_id() -> str | None:
    """
    Return the wandb sweep ID for the current `@wandb_sweep` task, if any.

    The ID is read from the Flyte run's custom context under the
    `_wandb_sweep_id` key.

    Returns:
        `str` | `None`: The sweep ID, or `None` when there is no Flyte
        context, the custom context is empty, or no sweep ID was recorded.
    """
    current_ctx = flyte.ctx()
    custom = current_ctx.custom_context if current_ctx else None
    if not custom:
        return None
    return custom.get("_wandb_sweep_id")
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def get_wandb_run_dir() -> Optional[str]:
    """
    Return the local filesystem directory of the active wandb run.

    The path comes straight from the in-process run object (`run.dir`), so no
    network calls are made. To read files produced by another task, or after
    a task has finished, use `download_wandb_run_dir()` instead.

    Returns:
        Local path to the wandb run directory (`wandb.run.dir`), or `None`
        when there is no active run.
    """
    active_run = get_wandb_run()
    return None if active_run is None else active_run.dir
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def download_wandb_run_dir(
    run_id: Optional[str] = None,
    path: Optional[str] = None,
    include_history: bool = True,
) -> str:
    """
    Download wandb run data from wandb cloud.

    Downloads all run files and optionally exports metrics history to JSON.
    This enables access to wandb data from any task or after workflow completion.

    Downloaded contents:

    - summary.json - summary metrics (written when the run has summary data)
    - metrics_history.json - metrics history rows as returned by wandb's
      `Run.history()` (written when `include_history=True` and history exists)
    - Plus any files synced by wandb (requirements.txt, wandb_metadata.json, etc.)

    Args:
        run_id: The wandb run ID to download. If `None`, uses the run ID stored
            in the Flyte custom context, falling back to the current run's ID
            (useful for shared runs across tasks).
        path: Local directory to download files to. If `None`, downloads to
            `/tmp/wandb_runs/{run_id}`.
        include_history: If `True`, exports the metrics history to
            `metrics_history.json`. Defaults to `True`.

    Returns:
        Local path where files were downloaded.

    Raises:
        `RuntimeError`: If no `run_id` is provided and none can be found in
            context, if the download directory cannot be created, or if
            fetching, downloading, or exporting the run data fails. Underlying
            wandb errors (e.g. `AuthenticationError`, `CommError`) are chained
            as the `__cause__` of the `RuntimeError`.

    Note:
        There may be a brief delay between when files are written locally and
        when they're available in wandb cloud. For immediate local access
        within the same task, use `get_wandb_run_dir()` instead.
    """
    # Resolve run_id: explicit argument, then the ID recorded in the Flyte
    # custom context, then the run bound to the current task.
    if run_id is None:
        ctx = flyte.ctx()
        if ctx and ctx.custom_context:
            run_id = ctx.custom_context.get("_wandb_run_id")
    if run_id is None:
        run = get_wandb_run()
        if run:
            run_id = run.id
    if run_id is None:
        raise RuntimeError(
            "No run_id provided and no active wandb run found in context. "
            "Provide a run_id explicitly or call from within a @wandb_init task."
        )

    # Build the API run path from whatever entity/project the context provides.
    wandb_ctx = get_wandb_context()
    entity = wandb_ctx.entity if wandb_ctx else None
    project = wandb_ctx.project if wandb_ctx else None
    if entity and project:
        run_path = f"{entity}/{project}/{run_id}"
    elif project:
        run_path = f"{project}/{run_id}"
    else:
        # wandb API can sometimes work with just run_id if logged in
        run_path = run_id

    if path is None:
        path = f"/tmp/wandb_runs/{run_id}"

    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create download directory {path}: {e}") from e

    # Fetch the run handle from wandb cloud.
    try:
        api = wandb.Api()
        api_run = api.run(run_path)
    except wandb.errors.AuthenticationError as e:
        # Must check AuthenticationError before CommError (it's a subclass)
        raise RuntimeError(
            f"Authentication failed when accessing wandb run '{run_path}'. "
            f"Please ensure WANDB_API_KEY is set correctly. Error: {e}"
        ) from e
    except wandb.errors.CommError as e:
        raise RuntimeError(
            f"Failed to fetch wandb run '{run_path}' from wandb cloud. "
            f"The run may not exist, or you may not have access to it. "
            f"Error: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error fetching wandb run '{run_path}': {e}") from e

    # Download every file synced for the run.
    try:
        for run_file in api_run.files():
            run_file.download(root=path, replace=True)
    except Exception as e:
        raise RuntimeError(f"Failed to download files for run '{run_id}': {e}") from e

    # Export summary to JSON (skipped when the run has no summary data).
    try:
        summary_data = dict(api_run.summary)
        if summary_data:
            with open(os.path.join(path, "summary.json"), "w", encoding="utf-8") as f:
                json.dump(summary_data, f, indent=2, default=str)
    except OSError as e:
        raise RuntimeError(f"Failed to write summary.json for run '{run_id}': {e}") from e
    except Exception as e:
        raise RuntimeError(f"Failed to export summary data for run '{run_id}': {e}") from e

    # Export metrics history to JSON
    if include_history:
        try:
            # pandas=False: by default wandb returns a pandas DataFrame when
            # pandas is installed. Its truth value is ambiguous (ValueError on
            # `if history`) and json cannot serialize it; a list of dicts works.
            history = api_run.history(pandas=False)
            if history:
                with open(os.path.join(path, "metrics_history.json"), "w", encoding="utf-8") as f:
                    json.dump(history, f, indent=2, default=str)
        except OSError as e:
            raise RuntimeError(f"Failed to write metrics_history.json for run '{run_id}': {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to export history data for run '{run_id}': {e}") from e

    return path
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def download_wandb_sweep_dirs(
    sweep_id: Optional[str] = None,
    base_path: Optional[str] = None,
    include_history: bool = True,
) -> list[str]:
    """
    Download all run data for a wandb sweep.

    Queries the wandb API for all runs in the sweep and downloads their files
    and metrics history via `download_wandb_run_dir()`. This is useful for
    collecting results from all sweep trials after completion. Individual run
    failures are logged as a warning and skipped, unless every run fails.

    Args:
        sweep_id: The wandb sweep ID. If `None`, uses the current sweep's ID
            from context (set by `@wandb_sweep` decorator).
        base_path: Base directory to download files to. Each run's files will be
            in a subdirectory named by run_id. If `None`, uses `/tmp/wandb_runs/`.
        include_history: If `True`, exports the metrics history to
            metrics_history.json for each run. Defaults to `True`.

    Returns:
        List of local paths where run data was downloaded. Empty if the sweep
        has no runs.

    Raises:
        RuntimeError: If no sweep_id is provided and no active sweep is found
            in context, if entity/project are not configured, if the sweep
            cannot be fetched (wandb `AuthenticationError`/`CommError` are
            chained as the cause), or if every run in the sweep fails to
            download.
    """
    if sweep_id is None:
        sweep_id = get_wandb_sweep_id()
    if sweep_id is None:
        raise RuntimeError(
            "No sweep_id provided and no active wandb sweep found in context. "
            "Provide a sweep_id explicitly or call from within a @wandb_sweep task."
        )

    # Sweep lookup requires a fully-qualified entity/project/sweep_id path.
    wandb_ctx = get_wandb_context()
    entity = wandb_ctx.entity if wandb_ctx else None
    project = wandb_ctx.project if wandb_ctx else None
    if not entity or not project:
        raise RuntimeError("Cannot query sweep without entity and project. Set them via wandb_config().")

    sweep_path = f"{entity}/{project}/{sweep_id}"
    try:
        api = wandb.Api()
        sweep = api.sweep(sweep_path)
        run_ids = [run.id for run in sweep.runs]
    except wandb.errors.AuthenticationError as e:
        # Must check AuthenticationError before CommError (it's a subclass)
        raise RuntimeError(
            f"Authentication failed when accessing wandb sweep '{sweep_path}'. "
            f"Please ensure WANDB_API_KEY is set correctly. Error: {e}"
        ) from e
    except wandb.errors.CommError as e:
        raise RuntimeError(
            f"Failed to fetch wandb sweep '{sweep_path}' from wandb cloud. "
            f"The sweep may not exist, or you may not have access to it. "
            f"Error: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error fetching wandb sweep '{sweep_path}': {e}") from e

    if not run_ids:
        # Nothing to download; surface it instead of silently returning [].
        logger.warning("Sweep '%s' has no runs to download.", sweep_id)
        return []

    root = base_path or "/tmp/wandb_runs"
    downloaded_paths: list[str] = []
    failed_runs: list[tuple[str, str]] = []

    for run_id in run_ids:
        path = os.path.join(root, run_id)
        try:
            download_wandb_run_dir(run_id=run_id, path=path, include_history=include_history)
        except Exception as e:
            # Record the failure but keep going with the remaining runs.
            failed_runs.append((run_id, str(e)))
        else:
            downloaded_paths.append(path)

    if failed_runs:
        failed_info = ", ".join(f"{rid} ({err})" for rid, err in failed_runs)
        if not downloaded_paths:
            raise RuntimeError(
                f"Failed to download all {len(run_ids)} runs for sweep '{sweep_id}'. Failed runs: {failed_info}"
            )
        # Partial success: warn and return what was downloaded.
        logger.warning(
            "Failed to download %d/%d runs for sweep '%s'. Failed runs: %s",
            len(failed_runs),
            len(run_ids),
            sweep_id,
            failed_info,
        )

    return downloaded_paths
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
@flyte.trace
async def download_wandb_run_logs(run_id: str) -> Dir:
    """
    Traced helper that downloads a wandb run's files into a Flyte `Dir`.

    This function is called automatically when `download_logs=True` is set
    in `@wandb_init` or `wandb_config()`. The downloaded files appear as a
    trace output in the Flyte UI.

    Args:
        run_id: The wandb run ID to download.

    Returns:
        Dir containing the downloaded wandb run files.

    Raises:
        RuntimeError: If download fails (network error, run not found, auth failure, etc.)
    """
    # NOTE(review): download_wandb_run_dir is synchronous and performs network
    # and disk I/O, so the event loop is blocked while it runs.
    local_dir = download_wandb_run_dir(run_id=run_id)
    run_dir = await Dir.from_local(local_dir)
    return run_dir
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@flyte.trace
async def download_wandb_sweep_logs(sweep_id: str) -> Dir:
    """
    Traced function to download wandb sweep logs after task completion.

    This function is called automatically when `download_logs=True` is set
    in `@wandb_sweep` or `wandb_sweep_config()`. The downloaded files appear as a
    trace output in the Flyte UI.

    Runs are downloaded into a directory dedicated to this sweep,
    `/tmp/wandb_sweeps/{sweep_id}`, with one subdirectory per run ID.

    Args:
        sweep_id: The wandb sweep ID to download.

    Returns:
        Dir containing the downloaded wandb sweep run files (empty if the
        sweep has no runs).

    Raises:
        RuntimeError: If download fails (network error, sweep not found, auth failure, etc.)
    """
    # Use a sweep-scoped root. The shared default (/tmp/wandb_runs) is also
    # where download_wandb_run_dir puts every other run download, so uploading
    # it would include runs that do not belong to this sweep.
    sweep_root = f"/tmp/wandb_sweeps/{sweep_id}"
    try:
        # Create it up front so an empty sweep still yields a valid directory.
        os.makedirs(sweep_root, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create download directory {sweep_root}: {e}") from e

    download_wandb_sweep_dirs(sweep_id=sweep_id, base_path=sweep_root)
    return await Dir.from_local(sweep_root)
|