flyteplugins-wandb 2.0.0b52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/wandb/__init__.py +493 -0
- flyteplugins/wandb/_context.py +381 -0
- flyteplugins/wandb/_decorator.py +417 -0
- flyteplugins/wandb/_link.py +149 -0
- flyteplugins_wandb-2.0.0b52.dist-info/METADATA +34 -0
- flyteplugins_wandb-2.0.0b52.dist-info/RECORD +8 -0
- flyteplugins_wandb-2.0.0b52.dist-info/WHEEL +5 -0
- flyteplugins_wandb-2.0.0b52.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, Literal, Optional, get_args, get_origin

import flyte
|
|
6
|
+
|
|
7
|
+
# Flyte-specific W&B run mode; it is never passed to wandb.init(). It controls
# whether tasks create a new W&B run or share an existing one. How "auto" is
# resolved happens outside this module (presumably in the task decorator — TODO confirm).
RunMode = Literal["auto", "new", "shared"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _to_dict_helper(obj, prefix: str) -> dict[str, str]:
    """Flatten a dataclass into the ``{prefix}_{field}`` string dict stored in Flyte's custom_context.

    Fields whose value is ``None`` are omitted. Lists, tuples, dicts and bools are
    JSON-encoded (tuples become JSON arrays) so they can be decoded again when the
    config is rebuilt; every other value is converted with ``str()``.

    Args:
        obj: Dataclass instance to serialize.
        prefix: Key prefix, e.g. ``"wandb"`` or ``"wandb_sweep"``.

    Returns:
        Mapping of ``{prefix}_{field}`` keys to string values.

    Raises:
        ValueError: If a list/tuple/dict value is not JSON-serializable.
    """
    result: dict[str, str] = {}
    for key, value in asdict(obj).items():
        if value is None:
            continue
        # Containers and bools are JSON-encoded so they survive the string-only
        # custom_context round trip; str(("a", "b")) would not be decodable.
        if isinstance(value, (list, tuple, dict, bool)):
            try:
                result[f"{prefix}_{key}"] = json.dumps(value)
            except (TypeError, ValueError) as e:
                raise ValueError(
                    f"{prefix} config field '{key}' must be JSON-serializable. "
                    f"Got type: {type(value).__name__}. Error: {e}"
                ) from e
        else:
            result[f"{prefix}_{key}"] = str(value)
    return result
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _field_expects_json(field_type) -> bool:
    """Return True if values of a dataclass field with this type are stored as JSON.

    The config serializer JSON-encodes list, dict and bool values and stores everything
    else via ``str()``. Union/``Optional[...]`` types are inspected member by member.
    """
    if isinstance(field_type, str):
        # Unresolved (stringified) annotation: fall back to best-effort JSON decoding.
        return True
    if field_type in (bool, list, dict):
        return True
    origin = get_origin(field_type)
    if origin in (list, dict):
        return True
    if origin is Literal:
        return False
    # Unions such as Optional[list[str]]: JSON if any non-None member is.
    return any(_field_expects_json(arg) for arg in get_args(field_type) if arg is not type(None))


def _from_dict_helper(cls, d: dict[str, str], prefix: str):
    """Rebuild a dataclass from the ``{prefix}_{field}`` entries of a custom_context dict.

    Only keys that map to a declared field of ``cls`` are used. Other keys sharing the
    prefix are ignored instead of making ``cls(**kwargs)`` raise ``TypeError``.
    Values of list/dict/bool-typed fields are JSON-decoded, falling back to the raw
    value if decoding fails. Values of all other fields (e.g. ``Optional[str]``) are
    kept as-is, so a string like ``"123"`` or ``"null"`` is not coerced to an int or
    ``None``.

    Args:
        cls: Dataclass type to instantiate.
        d: Flyte custom_context mapping.
        prefix: Key prefix, e.g. ``"wandb"`` or ``"wandb_sweep"``.

    Returns:
        A new ``cls`` instance; fields without a matching key keep their defaults.
    """
    prefix_with_underscore = f"{prefix}_"
    prefix_len = len(prefix_with_underscore)

    # Keys with a longer/more-specific prefix belong to another config
    # (e.g. when processing "wandb", "wandb_sweep_*" keys belong to the sweep config).
    exclude_prefixes = ("wandb_sweep_",) if prefix == "wandb" else ()

    field_types = {f.name: f.type for f in fields(cls) if f.init}

    kwargs = {}
    for key, value in d.items():
        if not key.startswith(prefix_with_underscore) or key.startswith(exclude_prefixes):
            continue

        field_name = key[prefix_len:]
        if field_name not in field_types:
            # The custom_context namespace is shared; unrelated keys are skipped.
            continue

        if _field_expects_json(field_types[field_name]):
            try:
                kwargs[field_name] = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                kwargs[field_name] = value
        else:
            kwargs[field_name] = value
    return cls(**kwargs)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _context_manager_enter(obj, prefix: str):
    """Shared ``__enter__`` logic for the wandb config context managers.

    Snapshots every ``{prefix}_*`` entry currently in the Flyte custom_context, then
    enters ``flyte.custom_context(**obj)`` so that ``obj``'s settings are applied.

    Args:
        obj: Config object supporting ``**`` unpacking (``keys`` / ``__getitem__``).
        prefix: Key prefix to snapshot, e.g. ``"wandb"`` or ``"wandb_sweep"``.

    Returns:
        ``(saved_config, ctx_mgr)``: the snapshot dict and the entered context manager.
    """
    # Note: with prefix "wandb" this also captures "wandb_sweep_*" entries,
    # because they start with "wandb_" too.
    key_prefix = f"{prefix}_"
    current = flyte.ctx()
    snapshot = {}
    if current and current.custom_context:
        live = current.custom_context
        snapshot = {key: live[key] for key in list(live.keys()) if key.startswith(key_prefix)}

    manager = flyte.custom_context(**obj)
    manager.__enter__()
    return snapshot, manager
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _context_manager_exit(ctx_mgr, saved_config: dict, prefix: str, *args):
    """Shared ``__exit__`` logic for the wandb config context managers.

    Exits ``ctx_mgr``, forwarding the exception info in ``args``. Then, if a non-empty
    Flyte custom_context is present, it drops every ``{prefix}_*`` key from it and
    re-applies the entries in ``saved_config``.

    Args:
        ctx_mgr: Context manager returned at enter time (skipped if falsy).
        saved_config: Snapshot of ``{prefix}_*`` entries taken at enter time.
        prefix: Key prefix to clean up, e.g. ``"wandb"`` or ``"wandb_sweep"``.
        *args: ``(exc_type, exc_value, traceback)`` forwarded to ``ctx_mgr.__exit__``.
    """
    if ctx_mgr:
        ctx_mgr.__exit__(*args)

    current = flyte.ctx()
    if not (current and current.custom_context):
        return

    key_prefix = f"{prefix}_"
    # Collect first so the mapping is not mutated while being iterated.
    stale_keys = [key for key in current.custom_context.keys() if key.startswith(key_prefix)]
    for key in stale_keys:
        del current.custom_context[key]
    current.custom_context.update(saved_config)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass
class _WandBConfig:
    """
    Weights & Biases run settings carried through Flyte's custom_context.

    Besides the explicit fields below, any other ``wandb.init()`` argument can be
    supplied through the ``kwargs`` dict, for example:
      - notes, job_type, save_code
      - resume, resume_from, fork_from, reinit
      - anonymous, allow_val_change, force
      - settings, and more

    Read-only mapping methods (``keys``, ``__getitem__``, ``items``, ``get``) expose
    the serialized ``wandb_*`` form, so the object can be ``**``-unpacked into a
    custom_context. The mutating mapping methods (``__setitem__``, ``__delitem__``,
    ``pop``, ``update``) operate on the live ``flyte.ctx().custom_context`` rather
    than on this object. They do nothing (``pop`` returns its default) when there is
    no context or it is empty.

    As a context manager, it applies its settings through ``flyte.custom_context``
    and restores the previous ``wandb_*`` entries on exit.

    See: https://docs.wandb.ai/ref/python/init
    """

    # Core wandb.init() arguments
    project: Optional[str] = None
    entity: Optional[str] = None
    id: Optional[str] = None
    name: Optional[str] = None
    tags: Optional[list[str]] = None
    config: Optional[dict[str, Any]] = None

    # Frequently used optional wandb.init() arguments
    mode: Optional[str] = None
    group: Optional[str] = None

    # Flyte-only: chooses between creating a new W&B run and sharing an existing
    # one ("auto", "new" or "shared"); never forwarded to wandb.init()
    run_mode: RunMode = "auto"

    # Flyte-only: fetch the wandb log files once the task has finished
    download_logs: bool = False

    # Any further wandb.init() keyword arguments
    kwargs: Optional[dict[str, Any]] = None

    @staticmethod
    def _live_custom_context():
        """Return the current Flyte custom_context if it exists and is non-empty, else None."""
        current = flyte.ctx()
        if current and current.custom_context:
            return current.custom_context
        return None

    def to_dict(self) -> dict[str, str]:
        """Serialize to the ``wandb_*`` string mapping used by Flyte's custom_context."""
        return _to_dict_helper(self, "wandb")

    @classmethod
    def from_dict(cls, d: dict[str, str]) -> "_WandBConfig":
        """Rebuild a config from the ``wandb_*`` entries of a custom_context mapping."""
        return _from_dict_helper(cls, d, "wandb")

    # --- Read-only mapping view (enables ** unpacking) ---
    def keys(self):
        serialized = self.to_dict()
        return serialized.keys()

    def __getitem__(self, key):
        serialized = self.to_dict()
        return serialized[key]

    def items(self):
        serialized = self.to_dict()
        return serialized.items()

    def get(self, key, default=None):
        serialized = self.to_dict()
        return serialized.get(key, default)

    # --- Mutations are forwarded to the live Flyte custom_context ---
    def __setitem__(self, key, value):
        store = self._live_custom_context()
        if store is not None:
            store[key] = value

    def __delitem__(self, key):
        store = self._live_custom_context()
        if store is not None:
            del store[key]

    def pop(self, key, default=None):
        store = self._live_custom_context()
        if store is None:
            return default
        return store.pop(key, default)

    def update(self, *args, **kwargs):
        store = self._live_custom_context()
        if store is not None:
            store.update(*args, **kwargs)

    # --- Context manager protocol ---
    def __enter__(self):
        saved, manager = _context_manager_enter(self, "wandb")
        self._saved_config = saved
        self._ctx = manager
        return self

    def __exit__(self, *args):
        _context_manager_exit(self._ctx, self._saved_config, "wandb", *args)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def get_wandb_context() -> Optional[_WandBConfig]:
    """Return the W&B run config stored in the current Flyte context, if any.

    Returns:
        A `_WandBConfig` built from the ``wandb_*`` keys of
        ``flyte.ctx().custom_context``. Returns ``None`` when there is no Flyte
        context, the custom context is empty, or it holds no run-config keys.
        ``wandb_sweep_*`` keys belong to the sweep config and do not count, so a
        sweep-only context yields ``None`` rather than an all-defaults config.
    """
    ctx = flyte.ctx()
    if ctx is None or not ctx.custom_context:
        return None

    # "wandb_sweep_*" keys also start with "wandb_", but _WandBConfig.from_dict
    # skips them; counting them here would return an empty default config.
    has_wandb_keys = any(
        k.startswith("wandb_") and not k.startswith("wandb_sweep_") for k in ctx.custom_context.keys()
    )
    if not has_wandb_keys:
        return None

    return _WandBConfig.from_dict(ctx.custom_context)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def wandb_config(
    project: Optional[str] = None,
    entity: Optional[str] = None,
    id: Optional[str] = None,
    name: Optional[str] = None,
    tags: Optional[list[str]] = None,
    config: Optional[dict[str, Any]] = None,
    mode: Optional[str] = None,
    group: Optional[str] = None,
    run_mode: RunMode = "auto",
    download_logs: bool = False,
    **kwargs: Any,
) -> _WandBConfig:
    """
    Build a W&B run configuration for Flyte tasks.

    The returned object can be used in two ways:
      1. passed to `flyte.with_runcontext()` to set the wandb config globally
      2. used as a context manager to override the config for specific tasks

    Args:
        project: W&B project name
        entity: W&B entity (team or username)
        id: Unique run id (auto-generated if not provided)
        name: Human-readable run name
        tags: Tags used to organize runs
        config: Dictionary of hyperparameters
        mode: "online", "offline" or "disabled"
        group: Group name for related runs
        run_mode: Flyte-specific run mode - "auto", "new" or "shared"; decides
            whether tasks create new W&B runs or share existing ones
        download_logs: If `True`, downloads the wandb run files after the task
            completes and shows them as a trace output in the Flyte UI
        **kwargs: Any further `wandb.init()` parameters

    Returns:
        A `_WandBConfig`; an empty ``**kwargs`` is stored as ``None``.
    """
    # An empty **kwargs collapses to None so it is left out of the serialized form.
    extra_init_params = kwargs or None
    return _WandBConfig(
        project=project,
        entity=entity,
        id=id,
        name=name,
        tags=tags,
        config=config,
        mode=mode,
        group=group,
        run_mode=run_mode,
        download_logs=download_logs,
        kwargs=extra_init_params,
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@dataclass
class _WandBSweepConfig:
    """
    Weights & Biases sweep settings carried through Flyte's custom_context.

    ``to_sweep_config()`` produces the dict handed to ``wandb.sweep()``, and
    ``to_dict()`` produces the ``wandb_sweep_*`` string mapping stored in the
    custom_context. Read-only mapping methods expose that serialized form (enabling
    ``**`` unpacking). The mutating mapping methods act on the live
    ``flyte.ctx().custom_context`` and do nothing when it is missing or empty. As a
    context manager, it applies its settings through ``flyte.custom_context`` and
    restores the previous ``wandb_sweep_*`` entries on exit.
    """

    # Core sweep definition
    name: Optional[str] = None
    method: Optional[str] = None
    metric: Optional[dict[str, Any]] = None
    parameters: Optional[dict[str, Any]] = None

    # Sweep metadata (not part of the wandb.sweep() config dict)
    project: Optional[str] = None
    entity: Optional[str] = None
    prior_runs: Optional[list[str]] = None

    # Flyte-only: fetch the sweep's wandb log files once the task has finished
    download_logs: bool = False

    # Any further sweep config keys
    # (e.g. early_terminate, name, description, command, controller, etc.)
    kwargs: Optional[dict[str, Any]] = None

    @staticmethod
    def _live_custom_context():
        """Return the current Flyte custom_context if it exists and is non-empty, else None."""
        current = flyte.ctx()
        if current and current.custom_context:
            return current.custom_context
        return None

    def to_sweep_config(self) -> dict[str, Any]:
        """Build the dict passed to ``wandb.sweep()``.

        The metadata/Flyte-only fields (project, entity, prior_runs, download_logs)
        are dropped. ``kwargs`` entries are merged in at the top level, overriding
        same-named fields, and any ``None`` values are removed.
        """
        non_sweep_fields = ("project", "entity", "prior_runs", "download_logs")
        raw = asdict(self)
        extra = raw.pop("kwargs", None) or {}
        merged = {key: val for key, val in raw.items() if key not in non_sweep_fields}
        merged.update(extra)
        return {key: val for key, val in merged.items() if val is not None}

    def to_dict(self) -> dict[str, str]:
        """Serialize to the ``wandb_sweep_*`` string mapping used by Flyte's custom_context."""
        return _to_dict_helper(self, "wandb_sweep")

    @classmethod
    def from_dict(cls, d: dict[str, str]) -> "_WandBSweepConfig":
        """Rebuild a sweep config from the ``wandb_sweep_*`` entries of a custom_context mapping."""
        return _from_dict_helper(cls, d, "wandb_sweep")

    # --- Read-only mapping view (enables ** unpacking) ---
    def keys(self):
        serialized = self.to_dict()
        return serialized.keys()

    def __getitem__(self, key):
        serialized = self.to_dict()
        return serialized[key]

    def items(self):
        serialized = self.to_dict()
        return serialized.items()

    def get(self, key, default=None):
        serialized = self.to_dict()
        return serialized.get(key, default)

    # --- Mutations are forwarded to the live Flyte custom_context ---
    def __setitem__(self, key, value):
        store = self._live_custom_context()
        if store is not None:
            store[key] = value

    def __delitem__(self, key):
        store = self._live_custom_context()
        if store is not None:
            del store[key]

    def pop(self, key, default=None):
        store = self._live_custom_context()
        if store is None:
            return default
        return store.pop(key, default)

    def update(self, *args, **kwargs):
        store = self._live_custom_context()
        if store is not None:
            store.update(*args, **kwargs)

    # --- Context manager protocol ---
    def __enter__(self):
        saved, manager = _context_manager_enter(self, "wandb_sweep")
        self._saved_config = saved
        self._ctx = manager
        return self

    def __exit__(self, *args):
        _context_manager_exit(self._ctx, self._saved_config, "wandb_sweep", *args)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def get_wandb_sweep_context() -> Optional[_WandBSweepConfig]:
    """Return the W&B sweep config stored in the current Flyte context, if any.

    Returns:
        A `_WandBSweepConfig` built from the ``wandb_sweep_*`` keys of
        ``flyte.ctx().custom_context``. Returns ``None`` if there is no Flyte
        context, the custom context is empty, or it contains no ``wandb_sweep_*`` key.
    """
    current = flyte.ctx()
    custom = current.custom_context if current is not None else None
    if not custom:
        return None
    if not any(key.startswith("wandb_sweep_") for key in custom.keys()):
        return None
    return _WandBSweepConfig.from_dict(custom)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def wandb_sweep_config(
    method: Optional[str] = None,
    metric: Optional[dict[str, Any]] = None,
    parameters: Optional[dict[str, Any]] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    prior_runs: Optional[list[str]] = None,
    name: Optional[str] = None,
    download_logs: bool = False,
    **kwargs: Any,
) -> _WandBSweepConfig:
    """
    Build a wandb sweep configuration for hyperparameter optimization.

    Args:
        method: Search strategy (e.g., "random", "grid", "bayes")
        metric: Metric to optimize (e.g., {"name": "loss", "goal": "minimize"})
        parameters: Parameter definitions for the sweep
        project: W&B project for the sweep
        entity: W&B entity for the sweep
        prior_runs: IDs of earlier runs to include in the sweep analysis
        name: Sweep name (auto-generated as `{run_name}-{action_name}` if not provided)
        download_logs: If `True`, downloads the files of all sweep runs after the task
            completes and shows them as a trace output in the Flyte UI
        **kwargs: Extra sweep config keys such as `early_terminate`, `description`, `command`, etc.

    Returns:
        A `_WandBSweepConfig`; an empty ``**kwargs`` is stored as ``None``.

    See: https://docs.wandb.ai/models/sweeps/sweep-config-keys
    """
    # An empty **kwargs collapses to None so it is left out of the serialized form.
    extra_sweep_keys = kwargs or None
    return _WandBSweepConfig(
        name=name,
        method=method,
        metric=metric,
        parameters=parameters,
        project=project,
        entity=entity,
        prior_runs=prior_runs,
        download_logs=download_logs,
        kwargs=extra_sweep_keys,
    )
|