flyteplugins-wandb 2.0.0b52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/wandb/__init__.py +493 -0
- flyteplugins/wandb/_context.py +381 -0
- flyteplugins/wandb/_decorator.py +417 -0
- flyteplugins/wandb/_link.py +149 -0
- flyteplugins_wandb-2.0.0b52.dist-info/METADATA +34 -0
- flyteplugins_wandb-2.0.0b52.dist-info/RECORD +8 -0
- flyteplugins_wandb-2.0.0b52.dist-info/WHEEL +5 -0
- flyteplugins_wandb-2.0.0b52.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from dataclasses import asdict
|
|
5
|
+
from inspect import iscoroutinefunction
|
|
6
|
+
from typing import Any, Callable, Optional, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
import flyte
|
|
9
|
+
from flyte._task import AsyncFunctionTaskTemplate
|
|
10
|
+
|
|
11
|
+
import wandb
|
|
12
|
+
|
|
13
|
+
from ._context import RunMode, get_wandb_context, get_wandb_sweep_context
|
|
14
|
+
from ._link import Wandb, WandbSweep
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _build_init_kwargs() -> dict[str, Any]:
    """Translate the current wandb context config into ``wandb.init()`` keyword arguments.

    Returns:
        A dict made of the config's free-form ``kwargs`` entry overlaid with every
        non-``None`` config field, excluding the Flyte-only ``run_mode`` and
        ``download_logs`` fields. Empty when no context config is set.
    """
    cfg = get_wandb_context()
    if not cfg:
        return {}

    fields = asdict(cfg)
    passthrough = fields.pop("kwargs", None) or {}

    # These fields steer the plugin itself and are not wandb.init() parameters.
    for flyte_only in ("run_mode", "download_logs"):
        fields.pop(flyte_only, None)

    # Explicit config fields take precedence over the free-form kwargs; unset (None) fields are skipped.
    merged = {**passthrough}
    merged.update((key, value) for key, value in fields.items() if value is not None)
    return merged
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _finish_wandb_run(run, exit_code: int) -> None:
    """Finish ``run`` with ``exit_code``; if that fails, retry once as failed and re-raise."""
    try:
        run.finish(exit_code=exit_code)
    except Exception:
        try:
            run.finish(exit_code=1)
        except Exception:
            # Best effort only: the first finish() error is the one that gets reported.
            logger.debug("Retry of wandb run.finish(exit_code=1) failed", exc_info=True)
        raise


@contextmanager
def _wandb_run(
    run_mode: RunMode = "auto",
    func: bool = False,
    **decorator_kwargs,
):
    """
    Context manager for wandb run lifecycle.

    Initializes wandb.init() when the context is entered.
    The initialized run is available via get_wandb_run().

    Args:
        run_mode: "auto" reuses the parent's run ID when one is stored in the Flyte
            context and creates a new run otherwise; "new" always creates a run;
            "shared" always reuses the parent's run ID.
        func: True when wrapping a plain function instead of a Flyte task. Only
            supported when no Flyte context exists (e.g. wandb.agent() callbacks).
        **decorator_kwargs: wandb.init() parameters; they override values taken
            from the wandb context config.

    Yields:
        The wandb run object.

    Raises:
        RuntimeError: if ``func`` is True inside a Flyte context, if no Flyte context
            exists in task mode, or if a parent run must be reused but no parent run
            ID is stored in the context.
    """
    # Try to get Flyte context
    ctx = flyte.ctx()

    # This enables @wandb_init to work in wandb.agent() callbacks (sweep objectives)
    if func and ctx is None:
        # Use config from decorator params (no lazy init for fallback mode)
        run = wandb.init(**decorator_kwargs)
        failed = False
        try:
            yield run
        except BaseException:
            failed = True
            raise
        finally:
            # Report a failing objective as failed instead of as a clean finish.
            if failed:
                run.finish(exit_code=1)
            else:
                run.finish()
        return
    elif func and ctx:
        raise RuntimeError(
            "@wandb_init cannot be applied to traces. Traces can access the parent's wandb run via get_wandb_run()."
        )

    if ctx is None:
        # Without this guard the attribute access below fails with an opaque AttributeError.
        raise RuntimeError("@wandb_init requires an active Flyte task context, but none was found.")

    # Save existing state to restore later
    saved_run_id = ctx.custom_context.get("_wandb_run_id")
    saved_run = ctx.data.get("_wandb_run")

    # Build init kwargs from context; decorator values win over context values
    init_kwargs = {**_build_init_kwargs(), **decorator_kwargs}

    # A run object already stored in ctx.data belongs to the enclosing task:
    # hand it out without initializing and without cleaning up (the parent owns it).
    if saved_run:
        yield saved_run
        return

    # Get current action name for run ID generation
    current_action = ctx.action.name

    # Determine if we should reuse parent's run
    if run_mode == "shared":
        should_reuse = True
    elif run_mode == "auto":
        should_reuse = bool(saved_run_id)
    else:
        should_reuse = False

    # Determine run ID (an explicitly configured ID always wins)
    if init_kwargs.get("id") is None:
        if should_reuse:
            if not saved_run_id:
                raise RuntimeError("Cannot reuse parent run: no parent run ID found")
            init_kwargs["id"] = saved_run_id
        else:
            init_kwargs["id"] = f"{ctx.action.run_name}-{current_action}"

    mode = flyte.ctx().mode

    # Configure reinit parameter (only for local mode)
    # In remote/shared mode, wandb handles run creation/joining automatically
    if mode == "local":
        if should_reuse:
            init_kwargs.setdefault("reinit", "return_previous")
        else:
            init_kwargs["reinit"] = "create_new"

    # Configure remote mode settings
    if mode == "remote":
        is_primary = not should_reuse
        # `or {}` also covers an explicit settings=None.
        # NOTE(review): a wandb.Settings instance (rather than a dict) is still not
        # supported by the ** expansion below - confirm whether callers pass one.
        existing_settings = init_kwargs.get("settings") or {}

        shared_config = {
            "mode": "shared",
            "x_primary": is_primary,
            "x_label": current_action,
        }
        if not is_primary:
            shared_config["x_update_finish_state"] = False

        init_kwargs["settings"] = wandb.Settings(**{**existing_settings, **shared_config})

    # Initialize wandb
    run = wandb.init(**init_kwargs)

    # Store run ID in custom_context (shared with child tasks and accessible to links)
    ctx.custom_context["_wandb_run_id"] = run.id

    # Store run object in ctx.data (task-local only and accessible to traces)
    ctx.data["_wandb_run"] = run

    exit_code = 0
    try:
        yield run
    except BaseException:
        # Remember the failure so the run is not reported as successful.
        exit_code = 1
        raise
    finally:
        try:
            # Remote/shared mode: always call finish() to flush data; for secondary
            # tasks x_update_finish_state=False keeps the shared run open.
            # Local mode: only the task that created the run may finish it, because
            # secondary tasks reuse the parent's run object.
            if run and (mode == "remote" or not should_reuse):
                _finish_wandb_run(run, exit_code)
        finally:
            # Always restore the previous context state, even if finish() raised.
            if saved_run_id is not None:
                ctx.custom_context["_wandb_run_id"] = saved_run_id
            else:
                ctx.custom_context.pop("_wandb_run_id", None)

            if saved_run is not None:
                ctx.data["_wandb_run"] = saved_run
            else:
                ctx.data.pop("_wandb_run", None)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def wandb_init(
    _func: Optional[F] = None,
    *,
    run_mode: RunMode = "auto",
    download_logs: Optional[bool] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    **kwargs,
) -> F:
    """
    Decorator to automatically initialize wandb for Flyte tasks and wandb sweep objectives.

    Args:
        run_mode: Controls whether to create a new W&B run or share an existing one:

            1. "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run
            2. "new": Always creates a new wandb run with a unique ID
            3. "shared": Always shares the parent's run ID (useful for child tasks)
        download_logs: If `True`, downloads wandb run files after task completes
            and shows them as a trace output in the Flyte UI. If None, uses
            the value from `wandb_config()` context if set.
        project: W&B project name (overrides context config if provided)
        entity: W&B entity/team name (overrides context config if provided)
        **kwargs: Additional `wandb.init()` parameters (tags, config, mode, etc.)

    Returns:
        The decorated task or function when used as `@wandb_init`, or a decorator
        when called with arguments (`@wandb_init(...)`).

    Decorator Order:
        For tasks, @wandb_init must be the outermost decorator:
            @wandb_init
            @env.task
            async def my_task():
                ...

    This decorator:
    1. Initializes wandb when the context manager is entered
    2. Auto-generates unique run ID from Flyte action context if not provided
    3. Makes the run available via get_wandb_run()
    4. Automatically adds a W&B link to the task in the Flyte UI
    5. Automatically finishes the run after completion
    6. Optionally downloads run logs as a trace output (if download_logs=True)
    """

    def decorator(func: F) -> F:
        # Build decorator kwargs dict to pass to _wandb_run
        decorator_kwargs = {}
        if project is not None:
            decorator_kwargs["project"] = project
        if entity is not None:
            decorator_kwargs["entity"] = entity
        decorator_kwargs.update(kwargs)

        # Regular (non-task) function: wrap the call itself.
        if not isinstance(func, AsyncFunctionTaskTemplate):
            if iscoroutinefunction(func):

                @functools.wraps(func)
                async def async_wrapper(*args, **wrapper_kwargs):
                    with _wandb_run(run_mode=run_mode, func=True, **decorator_kwargs):
                        return await func(*args, **wrapper_kwargs)

                return cast(F, async_wrapper)

            @functools.wraps(func)
            def sync_wrapper(*args, **wrapper_kwargs):
                with _wandb_run(run_mode=run_mode, func=True, **decorator_kwargs):
                    return func(*args, **wrapper_kwargs)

            return cast(F, sync_wrapper)

        # Flyte task (AsyncFunctionTaskTemplate): create a Wandb link.
        # Even if run_mode="shared", we still add a link - it will point to the parent's run.
        # A run ID passed to the decorator is forwarded so the link matches the ID
        # that wandb.init() is actually called with.
        wandb_link = Wandb(project=project, entity=entity, run_mode=run_mode, id=kwargs.get("id"))

        # Get existing links from the task and add wandb link (tolerate links=None)
        existing_links = getattr(func, "links", None) or ()

        # Use override to properly add the link to the task
        func = func.override(links=(*existing_links, wandb_link))

        # Wrap the task's execute method with wandb_run
        original_execute = func.execute

        async def wrapped_execute(*args, **exec_kwargs):
            with _wandb_run(run_mode=run_mode, **decorator_kwargs) as run:
                result = await original_execute(*args, **exec_kwargs)

            # After run finishes, optionally download logs
            should_download = download_logs
            if should_download is None:
                # Check context config
                ctx_config = get_wandb_context()
                should_download = ctx_config.download_logs if ctx_config else False

            if should_download and run:
                # Imported lazily from the package __init__ at call time.
                from . import download_wandb_run_logs

                await download_wandb_run_logs(run.id)

            return result

        func.execute = wrapped_execute

        return cast(F, func)

    if _func is None:
        return decorator
    return decorator(_func)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@contextmanager
def _create_sweep(project: Optional[str] = None, entity: Optional[str] = None, **decorator_kwargs):
    """Context manager that yields a wandb sweep ID, creating the sweep when needed.

    A sweep ID already stored in the Flyte context is yielded as-is. Otherwise a
    sweep is created from the sweep context config, its ID is stored in the Flyte
    context for the duration of the block and removed again on exit.

    Raises:
        RuntimeError: if no sweep has been created yet and no sweep config is set.
    """
    flyte_ctx = flyte.ctx()

    # Reuse a sweep that an enclosing task has already created.
    inherited_id = flyte_ctx.custom_context.get("_wandb_sweep_id")
    if inherited_id:
        yield inherited_id
        return

    sweep_cfg = get_wandb_sweep_context()
    if not sweep_cfg:
        raise RuntimeError(
            "No wandb sweep config found. Use wandb_sweep_config() "
            "with flyte.with_runcontext() or as a context manager."
        )

    # The regular wandb config only acts as the last fallback for project/entity.
    run_cfg = get_wandb_context()

    # Priority: decorator kwargs > sweep config > wandb config
    resolved_project = project or sweep_cfg.project
    if not resolved_project:
        resolved_project = run_cfg.project if run_cfg else None

    resolved_entity = entity or sweep_cfg.entity
    if not resolved_entity:
        resolved_entity = run_cfg.entity if run_cfg else None

    earlier_runs = sweep_cfg.prior_runs or []

    definition = sweep_cfg.to_sweep_config()

    # Derive the sweep name from the Flyte run and action names when none is configured.
    if not ("name" in definition and definition["name"] is not None):
        definition["name"] = f"{flyte_ctx.action.run_name}-{flyte_ctx.action.name}"

    new_sweep_id = wandb.sweep(
        sweep=definition,
        project=resolved_project,
        entity=resolved_entity,
        prior_runs=earlier_runs,
        **decorator_kwargs,
    )

    # Publish the ID through the context (accessible to links).
    flyte_ctx.custom_context["_wandb_sweep_id"] = new_sweep_id

    try:
        yield new_sweep_id
    finally:
        flyte_ctx.custom_context.pop("_wandb_sweep_id", None)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def wandb_sweep(
    _func: Optional[F] = None,
    *,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    download_logs: Optional[bool] = None,
    **kwargs,
) -> F:
    """
    Decorator to create a wandb sweep and make `sweep_id` available.

    This decorator:
    1. Creates a wandb sweep using config from context
    2. Makes `sweep_id` available via `get_wandb_sweep_id()`
    3. Automatically adds a W&B sweep link to the task
    4. Optionally downloads all sweep run logs as a trace output (if `download_logs=True`)

    Args:
        project: W&B project name (overrides context config if provided)
        entity: W&B entity/team name (overrides context config if provided)
        download_logs: if `True`, downloads all sweep run files after task completes
            and shows them as a trace output in the Flyte UI. If None, uses
            the value from wandb_sweep_config() context if set.
        **kwargs: additional `wandb.sweep()` parameters

    Returns:
        The decorated task when used as `@wandb_sweep`, or a decorator when called
        with arguments (`@wandb_sweep(...)`).

    Raises:
        RuntimeError: if the decorated object is not a Flyte task.

    Decorator Order:
        For tasks, @wandb_sweep must be the outermost decorator:
            @wandb_sweep
            @env.task
            async def my_task():
                ...
    """

    def decorator(func: F) -> F:
        # Only Flyte tasks (AsyncFunctionTaskTemplate) are supported
        if not isinstance(func, AsyncFunctionTaskTemplate):
            raise RuntimeError("@wandb_sweep can only be used with Flyte tasks.")

        # Create a WandbSweep link. The decorator's project/entity are forwarded so
        # the link honors the same overrides that are used to create the sweep.
        wandb_sweep_link = WandbSweep(project=project, entity=entity)

        # Get existing links from the task and add wandb sweep link (tolerate links=None)
        existing_links = getattr(func, "links", None) or ()

        # Use override to properly add the link to the task
        func = func.override(links=(*existing_links, wandb_sweep_link))

        original_execute = func.execute

        async def wrapped_execute(*args, **exec_kwargs):
            with _create_sweep(project=project, entity=entity, **kwargs) as sweep_id:
                result = await original_execute(*args, **exec_kwargs)

            # After sweep finishes, optionally download logs
            should_download = download_logs
            if should_download is None:
                # Check context config
                sweep_config = get_wandb_sweep_context()
                should_download = sweep_config.download_logs if sweep_config else False

            if should_download and sweep_id:
                # Imported lazily from the package __init__ at call time.
                from . import download_wandb_sweep_logs

                await download_wandb_sweep_logs(sweep_id)

            return result

        func.execute = wrapped_execute

        return cast(F, func)

    if _func is None:
        return decorator
    return decorator(_func)
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
|
|
4
|
+
from flyte import Link
|
|
5
|
+
|
|
6
|
+
from ._context import RunMode
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Wandb(Link):
    """
    Link that resolves to a Weights & Biases run page.

    Args:
        host: Base W&B host URL
        project: W&B project name (overrides context config if provided)
        entity: W&B entity/team name (overrides context config if provided)
        run_mode: Controls whether to create a new W&B run or share an existing one:

            1. "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run
            2. "new": Always creates a new wandb run with a unique ID
            3. "shared": Always shares the parent's run ID (useful for child tasks)
        id: Optional W&B run ID (overrides context config if provided)
        name: Link name in the Flyte UI
    """

    host: str = "https://wandb.ai"
    project: Optional[str] = None
    entity: Optional[str] = None
    run_mode: RunMode = "auto"
    id: Optional[str] = None
    name: str = "Weights & Biases"

    def get_link(
        self,
        run_name: str,
        project: str,
        domain: str,
        context: Dict[str, str],
        parent_action_name: str,
        action_name: str,
        pod_name: str,
        **kwargs,
    ) -> str:
        """Build the URL for this task's W&B run.

        Returns the bare host when project or entity cannot be resolved, the
        project page when run_mode is "shared" but no parent run ID is known,
        and the run page otherwise.
        """
        # An absent or empty context simply contributes no values.
        inherited = context or {}

        # Values set on the link win; the context fills in whatever is missing.
        wb_project = self.project or inherited.get("wandb_project")
        wb_entity = self.entity or inherited.get("wandb_entity")
        explicit_id = self.id or inherited.get("wandb_id")
        # Run ID stored by the parent task, if any.
        parent_id = inherited.get("_wandb_run_id")

        # Without both project and entity no valid link can be built.
        if not (wb_project and wb_entity):
            return self.host

        project_url = f"{self.host}/{wb_entity}/{wb_project}"

        if self.run_mode != "new" and parent_id:
            # "shared" and "auto" point at the parent's run when one exists.
            target_id = parent_id
        elif self.run_mode == "shared":
            # Sharing was requested but there is no parent run ID to link to.
            return project_url
        else:
            # New run: explicit ID if given, otherwise the generated one.
            target_id = explicit_id or f"{run_name}-{action_name}"

        return f"{project_url}/runs/{target_id}"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class WandbSweep(Link):
    """
    Link that resolves to a Weights & Biases sweep page.

    Args:
        host: Base W&B host URL
        project: W&B project name (overrides context config if provided)
        entity: W&B entity/team name (overrides context config if provided)
        id: Optional W&B sweep ID (overrides context config if provided)
        name: Link name in the Flyte UI
    """

    host: str = "https://wandb.ai"
    project: Optional[str] = None
    entity: Optional[str] = None
    id: Optional[str] = None
    name: str = "Weights & Biases Sweep"

    def get_link(
        self,
        run_name: str,
        project: str,
        domain: str,
        context: Dict[str, str],
        parent_action_name: str,
        action_name: str,
        pod_name: str,
        **kwargs,
    ) -> str:
        """Build the URL for the sweep associated with this task.

        Returns the bare host when project or entity cannot be resolved, the
        specific sweep page when a sweep ID is known, and the project's sweeps
        list page otherwise.
        """
        # An absent or empty context simply contributes no values.
        inherited = context or {}

        # Values set on the link win; the context fills in whatever is missing.
        wb_project = self.project or inherited.get("wandb_project")
        wb_entity = self.entity or inherited.get("wandb_entity")
        # Child tasks inherit the sweep ID from the parent that created the sweep.
        sweep = self.id or inherited.get("_wandb_sweep_id")

        if not (wb_project and wb_entity):
            return self.host

        sweeps_url = f"{self.host}/{wb_entity}/{wb_project}/sweeps"
        return f"{sweeps_url}/{sweep}" if sweep else sweeps_url
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: flyteplugins-wandb
|
|
3
|
+
Version: 2.0.0b52
|
|
4
|
+
Summary: Weights & Biases plugin for Flyte
|
|
5
|
+
Author: Flyte Contributors
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: wandb
|
|
9
|
+
Requires-Dist: flyte
|
|
10
|
+
|
|
11
|
+
# Weights & Biases Plugin
|
|
12
|
+
|
|
13
|
+
- Tasks decorated with `@wandb_init` or `@wandb_sweep` automatically get W&B links in the Flyte UI that point directly to the corresponding W&B runs or sweeps. Links retrieve project/entity from decorator parameters or context configuration (from `with_runcontext`).
|
|
14
|
+
- `@wandb_init` and `@wandb_sweep` must be the **outermost decorators** (applied after `@env.task`). For example:
|
|
15
|
+
|
|
16
|
+
```python
|
|
17
|
+
@wandb_init
|
|
18
|
+
@env.task
|
|
19
|
+
def my_task():
|
|
20
|
+
...
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
- By default (`run_mode="auto"`), child tasks automatically reuse their parent's W&B run if one exists, or create a new run if they're top-level tasks. You can override this with `run_mode="new"` (always create new) or `run_mode="shared"` (always reuse parent).
|
|
24
|
+
- `@wandb_init` should be applied to tasks (not traces). Traces can access the parent task's W&B run via `get_wandb_run()`. `@wandb_init` can also be applied to regular Python functions for use in `wandb.agent()` sweep callbacks.
|
|
25
|
+
- The wandb run can be accessed via `get_wandb_run()`, which returns the run object or `None` if not within a `@wandb_init` decorated task.
|
|
26
|
+
- When using `run_mode="shared"` or `run_mode="auto"` (with a parent run), child tasks reuse the parent's run ID. Configuration from `wandb_config()` is merged with decorator parameters.
|
|
27
|
+
- `wandb_config` can be used to pass configuration to tasks enclosed within the context manager and can also be provided via `with_runcontext`.
|
|
28
|
+
- When the context manager exits, the configuration falls back to the parent task's config.
|
|
29
|
+
- Arguments passed to `wandb_init` decorator are available only within the current task and traces and are not propagated to child tasks (use `wandb_config` for child tasks).
|
|
30
|
+
- At most 20 sweep agents can be launched at a time: https://docs.wandb.ai/models/sweeps/existing-project#3-launch-agents.
|
|
31
|
+
- `@wandb_sweep` creates a W&B sweep and adds a sweep link to the decorated task. The sweep ID is available via `get_wandb_sweep_id()`. For the parent task that creates the sweep, the link points to the project's sweeps list page. For child tasks, the link points to the specific sweep (they inherit the `sweep_id` from the parent's context).
|
|
32
|
+
- The objective function passed to `wandb.agent()` should be a vanilla Python function decorated with `@wandb_init` to initialize the run. You can access the run with `wandb.run` since the Flyte context won't be available during the objective function call.
|
|
33
|
+
- Set `download_logs=True` in `wandb_config` or `@wandb_init` to download W&B run logs after task completion. The I/O of this download functionality is traced by Flyte's `@flyte.trace`.
|
|
34
|
+
- Set `download_logs=True` in `wandb_sweep_config` or `@wandb_sweep` to download W&B sweep logs after task completion. The I/O of this download functionality is traced by Flyte's `@flyte.trace`.
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
flyteplugins/wandb/__init__.py,sha256=D5gqDOIy6ePcE2tcbNVsp9ZzxZKC6Qmd-6eHxNX3L88,15881
|
|
2
|
+
flyteplugins/wandb/_context.py,sha256=va_TlRhSW-QBbHhvKmIAggsLw5VFAq4gXMIu7n5ZKSA,12746
|
|
3
|
+
flyteplugins/wandb/_decorator.py,sha256=HenEVJI7kmDMQdHo6jDy3vXvjxT89CCYRBCR2CuGE3s,14785
|
|
4
|
+
flyteplugins/wandb/_link.py,sha256=tEzfW06GPsVMECGAnEhwNzCI2h0d0UnJHMqso6t8Pnw,5319
|
|
5
|
+
flyteplugins_wandb-2.0.0b52.dist-info/METADATA,sha256=oOQOpcjQa99Iy-bhYuop9KtFktttFzNzNYj4yQvUjBc,3058
|
|
6
|
+
flyteplugins_wandb-2.0.0b52.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
7
|
+
flyteplugins_wandb-2.0.0b52.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
|
|
8
|
+
flyteplugins_wandb-2.0.0b52.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
flyteplugins
|