flyteplugins-wandb 2.0.0b52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,493 @@
1
+ """
2
+ ## Key features:
3
+
4
+ - Automatic W&B run initialization with `@wandb_init` decorator
5
+ - Automatic W&B links in Flyte UI pointing to runs and sweeps
6
+ - Parent/child task support with automatic run reuse
7
+ - W&B sweep creation and management with `@wandb_sweep` decorator
8
+ - Configuration management with `wandb_config()` and `wandb_sweep_config()`
9
+
10
+ ## Basic usage:
11
+
12
+ 1. Simple task with W&B logging:
13
+
14
+ ```python
15
+ from flyteplugins.wandb import wandb_init, get_wandb_run
16
+
17
+ @wandb_init(project="my-project", entity="my-team")
18
+ @env.task
19
+ async def train_model(learning_rate: float) -> str:
20
+ wandb_run = get_wandb_run()
21
+ wandb_run.log({"loss": 0.5, "learning_rate": learning_rate})
22
+ return wandb_run.id
23
+ ```
24
+
25
+ 2. Parent/Child Tasks with Run Reuse:
26
+
27
+ ```python
28
+ @wandb_init # Automatically reuses parent's run ID
29
+ @env.task
30
+ async def child_task(x: int) -> str:
31
+ wandb_run = get_wandb_run()
32
+ wandb_run.log({"child_metric": x * 2})
33
+ return wandb_run.id
34
+
35
+ @wandb_init(project="my-project", entity="my-team")
36
+ @env.task
37
+ async def parent_task() -> str:
38
+ wandb_run = get_wandb_run()
39
+ wandb_run.log({"parent_metric": 100})
40
+
41
+ # Child reuses parent's run by default (run_mode="auto")
42
+ await child_task(5)
43
+ return wandb_run.id
44
+ ```
45
+
46
+ 3. Configuration with context manager:
47
+
48
+ ```python
49
+ from flyteplugins.wandb import wandb_config
50
+
51
+ r = flyte.with_runcontext(
52
+ custom_context=wandb_config(
53
+ project="my-project",
54
+ entity="my-team",
55
+ tags=["experiment-1"]
56
+ )
57
+ ).run(train_model, learning_rate=0.001)
58
+ ```
59
+
60
+ 4. Creating new runs for child tasks:
61
+
62
+ ```python
63
+ @wandb_init(run_mode="new") # Always creates a new run
64
+ @env.task
65
+ async def independent_child() -> str:
66
+ wandb_run = get_wandb_run()
67
+ wandb_run.log({"independent_metric": 42})
68
+ return wandb_run.id
69
+ ```
70
+
71
+ 5. Running sweep agents in parallel:
72
+
73
+ ```python
74
+ import asyncio
75
+ import wandb
+ from flyteplugins.wandb import get_wandb_context, get_wandb_sweep_id, wandb_config, wandb_init, wandb_sweep, wandb_sweep_config
76
+
77
+ @wandb_init
78
+ async def objective():
79
+ wandb_run = wandb.run
80
+ config = wandb_run.config
81
+ ...
82
+
83
+ wandb_run.log({"loss": loss_value})
84
+
85
+ @wandb_sweep
86
+ @env.task
87
+ async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
88
+ wandb.agent(sweep_id, function=objective, count=count, project=get_wandb_context().project)
89
+ return agent_id
90
+
91
+ @wandb_sweep
92
+ @env.task
93
+ async def run_parallel_sweep(num_agents: int = 2, trials_per_agent: int = 5) -> str:
94
+ sweep_id = get_wandb_sweep_id()
95
+
96
+ # Launch agents in parallel
97
+ agent_tasks = [
98
+ sweep_agent(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
99
+ for i in range(num_agents)
100
+ ]
101
+
102
+ # Wait for all agents to complete
103
+ await asyncio.gather(*agent_tasks)
104
+ return sweep_id
105
+
106
+ # Run with 2 parallel agents
107
+ r = flyte.with_runcontext(
108
+ custom_context={
109
+ **wandb_config(project="my-project", entity="my-team"),
110
+ **wandb_sweep_config(
111
+ method="random",
112
+ metric={"name": "loss", "goal": "minimize"},
113
+ parameters={
114
+ "learning_rate": {"min": 0.0001, "max": 0.1},
115
+ "batch_size": {"values": [16, 32, 64]},
116
+ }
117
+ )
118
+ }
119
+ ).run(run_parallel_sweep, num_agents=2, trials_per_agent=5)
120
+ ```
121
+
122
+ Decorator order: `@wandb_init` or `@wandb_sweep` must be the outermost decorator:
123
+
124
+ ```python
125
+ @wandb_init
126
+ @env.task
127
+ async def my_task():
128
+ ...
129
+ ```
130
+ """
131
+
132
+ import json
133
+ import logging
134
+ import os
135
+ from typing import Optional
136
+
137
+ import flyte
138
+ from flyte.io import Dir
139
+
140
+ import wandb
141
+
142
+ from ._context import (
143
+ get_wandb_context,
144
+ get_wandb_sweep_context,
145
+ wandb_config,
146
+ wandb_sweep_config,
147
+ )
148
+ from ._decorator import wandb_init, wandb_sweep
149
+ from ._link import Wandb, WandbSweep
150
+
151
# Module-level logger named after this package; used below to warn about
# partially failed sweep downloads.
logger = logging.getLogger(__name__)


# Public API: names re-exported from the sibling modules (`_context`,
# `_decorator`, `_link`) plus the helpers defined in this module.
__all__ = [
    "Wandb",
    "WandbSweep",
    "download_wandb_run_dir",
    "download_wandb_run_logs",
    "download_wandb_sweep_dirs",
    "download_wandb_sweep_logs",
    "get_wandb_context",
    "get_wandb_run",
    "get_wandb_run_dir",
    "get_wandb_sweep_context",
    "get_wandb_sweep_id",
    "wandb_config",
    "wandb_init",
    "wandb_sweep",
    "wandb_sweep_config",
]


# NOTE(review): the diff header lists the wheel as 2.0.0b52 while this constant
# says 0.1.0 -- confirm whether it is meant to track the distribution version.
__version__ = "0.1.0"
174
+
175
+
176
def get_wandb_run():
    """
    Return the wandb run stored in the current Flyte context, if there is one.

    The run is looked up under the `_wandb_run` key of the context's `data`
    mapping. Per the package documentation this entry is populated while a
    `@wandb_init` decorated task or trace is executing.

    Returns:
        `wandb.sdk.wandb_run.Run` | `None`: The stored run object, or `None`
        when there is no Flyte context, the context carries no data, or no
        run has been stored.
    """
    flyte_ctx = flyte.ctx()
    # A missing context and an empty data mapping are treated the same way.
    ctx_data = flyte_ctx.data if flyte_ctx else None
    return ctx_data.get("_wandb_run") if ctx_data else None
191
+
192
+
193
def get_wandb_sweep_id() -> str | None:
    """
    Return the wandb sweep ID recorded in the current Flyte custom context.

    The ID is read from the `_wandb_sweep_id` key of `custom_context`. Per the
    package documentation it is available inside a `@wandb_sweep` decorated
    task.

    Returns:
        `str` | `None`: The sweep ID, or `None` when there is no Flyte context,
        the custom context is empty, or no sweep ID has been recorded.
    """
    flyte_ctx = flyte.ctx()
    # A missing context and an empty custom context are treated the same way.
    custom = flyte_ctx.custom_context if flyte_ctx else None
    return custom.get("_wandb_sweep_id") if custom else None
207
+
208
+
209
def get_wandb_run_dir() -> Optional[str]:
    """
    Return the local directory of the current wandb run.

    This only reads the `dir` attribute of the run returned by
    `get_wandb_run()`, so it involves no network access. To fetch files that
    belong to another task's run (or to a finished run), use
    `download_wandb_run_dir()` instead.

    Returns:
        The run's local directory (`wandb.run.dir`), or `None` when no run is
        active in the current context.
    """
    active_run = get_wandb_run()
    return None if active_run is None else active_run.dir
225
+
226
+
227
def download_wandb_run_dir(
    run_id: Optional[str] = None,
    path: Optional[str] = None,
    include_history: bool = True,
) -> str:
    """
    Download wandb run data from wandb cloud.

    Fetches the run through `wandb.Api()`, downloads every file attached to it
    and exports the summary (and optionally the metrics history) to JSON.
    This enables access to wandb data from any task or after workflow completion.

    Downloaded contents:

    - summary.json - final summary metrics (written when the summary is non-empty)
    - metrics_history.json - metrics history rows (if include_history=True and
      the run has history)
    - Plus any files synced by wandb (requirements.txt, wandb_metadata.json, etc.)

    Args:
        run_id: The wandb run ID to download. If `None`, the ID is taken from
            the `_wandb_run_id` entry of the Flyte custom context (useful for
            shared runs across tasks), falling back to the active run's ID.
        path: Local directory to download files to. If `None`, downloads to
            `/tmp/wandb_runs/{run_id}`.
        include_history: If `True`, exports the metrics history to
            `metrics_history.json`. Defaults to `True`.

    Returns:
        Local path where files were downloaded.

    Raises:
        `RuntimeError`: If no `run_id` is provided and none can be found in
            context, if the download directory cannot be created, or if
            fetching the run, downloading its files, or exporting the JSON
            files fails. Errors from wandb (authentication, communication,
            run not found) are wrapped; the original exception is attached as
            `__cause__`.

    Note:
        There may be a brief delay between when files are written locally and
        when they're available in wandb cloud. For immediate local access
        within the same task, use `get_wandb_run_dir()` instead.
    """
    # Resolve the run ID: explicit argument first, then the ID shared through
    # the Flyte custom context, then the live run object.
    if run_id is None:
        ctx = flyte.ctx()
        if ctx and ctx.custom_context:
            run_id = ctx.custom_context.get("_wandb_run_id")
        if run_id is None:
            run = get_wandb_run()
            if run:
                run_id = run.id
        if run_id is None:
            raise RuntimeError(
                "No run_id provided and no active wandb run found in context. "
                "Provide a run_id explicitly or call from within a @wandb_init task."
            )

    # Get entity/project from context
    wandb_ctx = get_wandb_context()
    entity = wandb_ctx.entity if wandb_ctx else None
    project = wandb_ctx.project if wandb_ctx else None

    # Build the most specific run path the available context allows.
    if entity and project:
        run_path = f"{entity}/{project}/{run_id}"
    elif project:
        run_path = f"{project}/{run_id}"
    else:
        # wandb API can sometimes work with just run_id if logged in
        run_path = run_id

    # Set download path
    if path is None:
        path = f"/tmp/wandb_runs/{run_id}"

    # Ensure directory exists
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create download directory {path}: {e}") from e

    # Look the run up in wandb cloud
    try:
        api = wandb.Api()
        api_run = api.run(run_path)
    except wandb.errors.AuthenticationError as e:
        # Must check AuthenticationError before CommError (it's a subclass)
        raise RuntimeError(
            f"Authentication failed when accessing wandb run '{run_path}'. "
            f"Please ensure WANDB_API_KEY is set correctly. Error: {e}"
        ) from e
    except wandb.errors.CommError as e:
        raise RuntimeError(
            f"Failed to fetch wandb run '{run_path}' from wandb cloud. "
            f"The run may not exist, or you may not have access to it. "
            f"Error: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error fetching wandb run '{run_path}': {e}") from e

    # Download every file attached to the run, overwriting stale local copies.
    try:
        for file in api_run.files():
            file.download(root=path, replace=True)
    except Exception as e:
        raise RuntimeError(f"Failed to download files for run '{run_id}': {e}") from e

    # Export summary to JSON
    try:
        summary_data = dict(api_run.summary)
        if summary_data:
            with open(os.path.join(path, "summary.json"), "w") as f:
                json.dump(summary_data, f, indent=2, default=str)
    except OSError as e:
        raise RuntimeError(f"Failed to write summary.json for run '{run_id}': {e}") from e
    except Exception as e:
        raise RuntimeError(f"Failed to export summary data for run '{run_id}': {e}") from e

    # Export metrics history to JSON
    if include_history:
        try:
            # Ask for plain records: by default `history()` returns a pandas
            # DataFrame when pandas is installed, and a DataFrame can neither
            # be truth-tested (`if history:` raises ValueError) nor dumped as
            # structured JSON.
            # NOTE(review): per the wandb docs `history()` is sampled (default
            # `samples=500`); use `scan_history()` if every step is required.
            history = api_run.history(pandas=False)
            if history:
                with open(os.path.join(path, "metrics_history.json"), "w") as f:
                    json.dump(history, f, indent=2, default=str)
        except OSError as e:
            raise RuntimeError(f"Failed to write metrics_history.json for run '{run_id}': {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to export history data for run '{run_id}': {e}") from e

    return path
352
+
353
+
354
def download_wandb_sweep_dirs(
    sweep_id: Optional[str] = None,
    base_path: Optional[str] = None,
    include_history: bool = True,
) -> list[str]:
    """
    Download the data of every run that belongs to a wandb sweep.

    The sweep is looked up through the wandb API and each of its runs is
    downloaded with `download_wandb_run_dir()`. A run that fails to download
    does not abort the others: failures are collected and reported at the end.

    Args:
        sweep_id: The wandb sweep ID. If `None`, uses the sweep ID from the
            current context (see `get_wandb_sweep_id()`).
        base_path: Base directory to download files to. Each run's files go to
            a subdirectory named by run_id. If `None`, uses `/tmp/wandb_runs`.
        include_history: Forwarded to `download_wandb_run_dir()`; if `True`,
            metrics history is exported for each run. Defaults to `True`.

    Returns:
        List of local paths of the runs that were downloaded successfully.
        If only some runs failed, a warning is logged and the successful
        paths are still returned.

    Raises:
        RuntimeError: If no sweep_id is provided and none is found in context,
            if entity or project is missing from the wandb context, if the
            sweep cannot be fetched (wandb errors are wrapped, original
            exception attached as `__cause__`), or if every run failed to
            download.
    """
    # Fall back to the sweep recorded in the Flyte context.
    if sweep_id is None:
        sweep_id = get_wandb_sweep_id()
    if sweep_id is None:
        raise RuntimeError(
            "No sweep_id provided and no active wandb sweep found in context. "
            "Provide a sweep_id explicitly or call from within a @wandb_sweep task."
        )

    # The sweep path needs both entity and project, taken from the wandb context.
    wandb_ctx = get_wandb_context()
    entity, project = (wandb_ctx.entity, wandb_ctx.project) if wandb_ctx else (None, None)
    if not (entity and project):
        raise RuntimeError("Cannot query sweep without entity and project. Set them via wandb_config().")

    # Resolve the IDs of all runs in the sweep.
    try:
        sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
        run_ids = [run.id for run in sweep.runs]
    except wandb.errors.AuthenticationError as e:
        # AuthenticationError is a CommError subclass, so it has to come first.
        raise RuntimeError(
            f"Authentication failed when accessing wandb sweep '{entity}/{project}/{sweep_id}'. "
            f"Please ensure WANDB_API_KEY is set correctly. Error: {e}"
        ) from e
    except wandb.errors.CommError as e:
        raise RuntimeError(
            f"Failed to fetch wandb sweep '{entity}/{project}/{sweep_id}' from wandb cloud. "
            f"The sweep may not exist, or you may not have access to it. "
            f"Error: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error fetching wandb sweep '{entity}/{project}/{sweep_id}': {e}") from e

    # Download run by run, remembering which ones failed and why.
    downloaded_paths = []
    failed_runs = []
    for run_id in run_ids:
        path = f"{base_path or '/tmp/wandb_runs'}/{run_id}"
        try:
            download_wandb_run_dir(run_id=run_id, path=path, include_history=include_history)
        except Exception as e:
            # Keep going: one broken run should not lose the rest of the sweep.
            failed_runs.append((run_id, str(e)))
        else:
            downloaded_paths.append(path)

    if not failed_runs:
        return downloaded_paths

    failed_info = ", ".join(f"{rid} ({err})" for rid, err in failed_runs)
    if not downloaded_paths:
        # Nothing succeeded, so there is no partial result worth returning.
        raise RuntimeError(
            f"Failed to download all {len(run_ids)} runs for sweep '{sweep_id}'. Failed runs: {failed_info}"
        )

    # Partial success: report the failures but hand back what was downloaded.
    logger.warning(
        f"Failed to download {len(failed_runs)}/{len(run_ids)} runs for sweep '{sweep_id}'. "
        f"Failed runs: {failed_info}"
    )
    return downloaded_paths
447
+
448
+
449
@flyte.trace
async def download_wandb_run_logs(run_id: str) -> Dir:
    """
    Download a wandb run's files and return them as a Flyte `Dir`.

    Wrapped in `@flyte.trace`. Per the package documentation it is invoked
    automatically after a task completes when `download_logs=True` is set in
    `@wandb_init` or `wandb_config()`, and its result shows up as a trace
    output in the Flyte UI.

    Args:
        run_id: The wandb run ID to download.

    Returns:
        Dir built from the local directory the run was downloaded to.

    Raises:
        RuntimeError: If download fails (network error, run not found, auth failure, etc.)
    """
    local_run_dir = download_wandb_run_dir(run_id=run_id)
    run_files = await Dir.from_local(local_run_dir)
    return run_files
469
+
470
+
471
@flyte.trace
async def download_wandb_sweep_logs(sweep_id: str) -> Dir:
    """
    Traced function to download wandb sweep logs after task completion.

    Per the package documentation this function is called automatically when
    `download_logs=True` is set in `@wandb_sweep` or `wandb_sweep_config()`,
    and the downloaded files appear as a trace output in the Flyte UI.

    Args:
        sweep_id: The wandb sweep ID to download.

    Returns:
        Dir built from the directory that contains the per-run subdirectories
        (the parent of the first downloaded run). With the default download
        location this is `/tmp/wandb_runs`, which also holds any other runs
        previously downloaded there by this process. For a sweep without runs
        the (possibly empty) fallback directory `/tmp/wandb_runs` is used.

    Raises:
        RuntimeError: If download fails (network error, sweep not found, auth failure, etc.)
    """
    paths = download_wandb_sweep_dirs(sweep_id=sweep_id)

    if paths:
        # All runs share one parent, so the first path identifies the base directory.
        base_path = os.path.dirname(paths[0])
    else:
        # A sweep with no runs downloads nothing, so the default directory may
        # never have been created; make sure it exists before handing it over.
        base_path = "/tmp/wandb_runs"
        os.makedirs(base_path, exist_ok=True)

    return await Dir.from_local(base_path)