flyteplugins-wandb 2.0.0b52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
1
+ import json
2
+ from dataclasses import asdict, dataclass
3
+ from typing import Any, Literal, Optional
4
+
5
+ import flyte
6
+
7
# Allowed values for the Flyte-specific `_WandBConfig.run_mode` field, which controls
# whether a task creates a new W&B run or shares an existing one. How "auto" chooses
# between the two is decided by the consumer of the config (not shown in this module).
RunMode = Literal["auto", "new", "shared"]
8
+
9
+
10
def _to_dict_helper(obj, prefix: str) -> dict[str, str]:
    """Flatten a config dataclass into the string-only mapping used by Flyte's custom_context.

    Every field whose value is not None becomes an entry keyed ``f"{prefix}_{field}"``.
    Lists, dicts and bools are JSON-encoded; any other value is converted with ``str()``.

    Args:
        obj: Dataclass instance to flatten.
        prefix: Key prefix, e.g. ``"wandb"`` or ``"wandb_sweep"``.

    Returns:
        A new ``dict[str, str]``.

    Raises:
        ValueError: If a list/dict/bool field cannot be JSON-encoded.
    """
    flattened: dict[str, str] = {}
    for field_name, field_value in asdict(obj).items():
        if field_value is None:
            # Unset fields are omitted entirely rather than stored as "None".
            continue
        context_key = f"{prefix}_{field_name}"
        if not isinstance(field_value, (list, dict, bool)):
            flattened[context_key] = str(field_value)
            continue
        # Containers and bools go through JSON so they can be decoded back into objects.
        try:
            encoded = json.dumps(field_value)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"wandb config field '{field_name}' must be JSON-serializable. "
                f"Got type: {type(field_value).__name__}. Error: {e}"
            ) from e
        flattened[context_key] = encoded
    return flattened
27
+
28
+
29
def _is_str_annotation(annotation) -> bool:
    """Return True if *annotation* only admits strings: str, Optional[str], or Literal of str."""
    import types
    from typing import Literal, Union, get_args, get_origin

    if annotation is str:
        return True
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is Literal:
        return bool(args) and all(isinstance(arg, str) for arg in args)
    # typing.Union covers Optional[X]; types.UnionType covers `X | None` on 3.10+.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if origin in union_origins:
        non_none = [arg for arg in args if arg is not type(None)]
        return bool(non_none) and all(_is_str_annotation(arg) for arg in non_none)
    return False


def _from_dict_helper(cls, d: dict[str, str], prefix: str):
    """Rebuild a config dataclass from a Flyte custom_context mapping.

    Inverse of ``_to_dict_helper``. It selects the keys that start with ``f"{prefix}_"``,
    strips the prefix to get the field name, and decodes the value.

    Fields annotated as strings (``str``, ``Optional[str]``, or a ``Literal`` of strings)
    keep the raw string. ``_to_dict_helper`` stores them with ``str()``, not JSON, so
    JSON-decoding them would corrupt values such as id ``"123"`` (would become ``123``),
    name ``"null"`` (would become ``None``) or project ``"1e5"`` (would become a float).
    Every other value is JSON-decoded when possible and otherwise kept as the raw string.

    Args:
        cls: Dataclass type to instantiate.
        d: Mapping of context keys to string values (e.g. ``ctx.custom_context``).
        prefix: Key prefix, e.g. ``"wandb"`` or ``"wandb_sweep"``.

    Returns:
        An instance of *cls*.

    Raises:
        TypeError: If a matching key does not correspond to a field of *cls*
            (raised by the dataclass constructor).
    """
    from dataclasses import fields
    from typing import get_type_hints

    prefix_with_underscore = f"{prefix}_"
    prefix_len = len(prefix_with_underscore)

    # "wandb_" is also a prefix of "wandb_sweep_". When rebuilding the run config,
    # skip the sweep keys so the two configs stay separate.
    exclude_prefixes = ("wandb_sweep_",) if prefix == "wandb" else ()

    # Resolve annotations once so string-typed fields can bypass JSON decoding.
    hints = get_type_hints(cls)
    string_fields = {
        f.name for f in fields(cls) if _is_str_annotation(hints.get(f.name, f.type))
    }

    kwargs = {}
    for key, value in d.items():
        if not key.startswith(prefix_with_underscore) or key.startswith(exclude_prefixes):
            continue
        field_name = key[prefix_len:]
        if field_name in string_fields:
            kwargs[field_name] = value
            continue
        try:
            kwargs[field_name] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            # Not JSON (or not a str/bytes): keep the stored value unchanged.
            kwargs[field_name] = value
    return cls(**kwargs)
53
+
54
+
55
def _context_manager_enter(obj, prefix: str):
    """Shared ``__enter__`` logic for the wandb config context managers.

    Snapshots every current custom_context entry whose key starts with
    ``f"{prefix}_"``, then enters ``flyte.custom_context`` with *obj* unpacked
    as keyword arguments.

    Args:
        obj: Config object exposing ``keys()``/``__getitem__`` so it can be ``**``-unpacked.
        prefix: Key prefix whose existing entries are snapshotted.

    Returns:
        ``(saved_config, ctx_mgr)``: the snapshot dict and the already-entered
        context manager, to be passed to ``_context_manager_exit``.
    """
    marker = f"{prefix}_"
    current = flyte.ctx()
    snapshot = {}
    if current and current.custom_context:
        snapshot = {
            key: current.custom_context[key]
            for key in list(current.custom_context.keys())
            if key.startswith(marker)
        }

    manager = flyte.custom_context(**obj)
    manager.__enter__()
    return snapshot, manager
67
+
68
+
69
def _context_manager_exit(ctx_mgr, saved_config: dict, prefix: str, *args):
    """Shared ``__exit__`` logic for the wandb config context managers.

    Exits *ctx_mgr* (forwarding the exception details in *args*). It then removes every
    custom_context key starting with ``f"{prefix}_"`` and re-applies *saved_config*.

    Args:
        ctx_mgr: Context manager returned by ``_context_manager_enter``; skipped if falsy.
        saved_config: Snapshot of the prefixed keys taken on enter.
        prefix: Key prefix whose entries are reset.
        *args: ``(exc_type, exc_value, traceback)`` forwarded to ``ctx_mgr.__exit__``.
    """
    if ctx_mgr:
        ctx_mgr.__exit__(*args)

    current = flyte.ctx()
    if not (current and current.custom_context):
        return

    # NOTE: for prefix "wandb" this marker also matches "wandb_sweep_" keys; the
    # snapshot taken on enter used the same match, so those are restored as well.
    marker = f"{prefix}_"
    # Collect the matching keys first so the mapping is not mutated while iterating.
    stale_keys = [key for key in current.custom_context.keys() if key.startswith(marker)]
    for key in stale_keys:
        del current.custom_context[key]
    current.custom_context.update(saved_config)
80
+
81
+
82
@dataclass
class _WandBConfig:
    """Run-level W&B settings carried through Flyte's ``custom_context``.

    Reading: an instance behaves like a read-only mapping of ``wandb_*`` string keys
    (from ``to_dict()``), so it can be ``**``-unpacked, e.g. into ``flyte.custom_context``.

    Writing: the mutating mapping methods operate on the live Flyte custom_context, not
    on this object.

    The instance is also a context manager. On exit the ``wandb_*`` keys are reset to
    the values they had on entry.

    Any other wandb.init() parameter goes in the ``kwargs`` dict, e.g.:
    - notes, job_type, save_code
    - resume, resume_from, fork_from, reinit
    - anonymous, allow_val_change, force
    - settings, and more

    See: https://docs.wandb.ai/ref/python/init
    """

    # Most commonly used wandb.init() parameters.
    project: Optional[str] = None
    entity: Optional[str] = None
    id: Optional[str] = None
    name: Optional[str] = None
    tags: Optional[list[str]] = None
    config: Optional[dict[str, Any]] = None

    # Other common wandb.init() parameters.
    mode: Optional[str] = None
    group: Optional[str] = None

    # Flyte-specific, not passed to wandb.init(): create a new W&B run or share one.
    run_mode: RunMode = "auto"  # "auto", "new", or "shared"

    # Flyte-specific: download wandb logs after the task completes.
    download_logs: bool = False

    # Any additional wandb.init() keyword arguments.
    kwargs: Optional[dict[str, Any]] = None

    # Serialization.

    def to_dict(self) -> dict[str, str]:
        """Flatten into ``wandb_``-prefixed string entries for Flyte's custom_context."""
        return _to_dict_helper(self, "wandb")

    @classmethod
    def from_dict(cls, d: dict[str, str]) -> "_WandBConfig":
        """Rebuild a config from the ``wandb_*`` entries of a custom_context mapping."""
        return _from_dict_helper(cls, d, "wandb")

    # Mapping protocol. keys() + __getitem__ are what ``**config`` unpacking needs.
    # Reads come from to_dict(); writes go to the live Flyte custom_context.

    @staticmethod
    def _live_custom_context():
        """Return the current Flyte custom_context if it exists and is non-empty, else None."""
        current = flyte.ctx()
        if current and current.custom_context:
            return current.custom_context
        return None

    def keys(self):
        return self.to_dict().keys()

    def items(self):
        return self.to_dict().items()

    def __getitem__(self, key):
        flat = self.to_dict()
        return flat[key]

    def get(self, key, default=None):
        flat = self.to_dict()
        return flat[key] if key in flat else default

    def __setitem__(self, key, value):
        store = self._live_custom_context()
        if store is not None:
            store[key] = value

    def __delitem__(self, key):
        store = self._live_custom_context()
        if store is not None:
            del store[key]

    def pop(self, key, default=None):
        store = self._live_custom_context()
        return default if store is None else store.pop(key, default)

    def update(self, *args, **kwargs):
        store = self._live_custom_context()
        if store is not None:
            store.update(*args, **kwargs)

    # Context manager.

    def __enter__(self):
        saved, manager = _context_manager_enter(self, "wandb")
        self._saved_config = saved
        self._ctx = manager
        return self

    def __exit__(self, *args):
        _context_manager_exit(self._ctx, self._saved_config, "wandb", *args)
170
+
171
+
172
def get_wandb_context() -> Optional[_WandBConfig]:
    """Return the W&B run config stored in the current Flyte context.

    Returns:
        A ``_WandBConfig`` built from the context's ``wandb_*`` keys, or ``None`` in
        any of these cases:
        - there is no Flyte context;
        - its custom_context is empty;
        - it contains no run-level ``wandb_*`` keys. Keys under the ``wandb_sweep_``
          prefix belong to the sweep config and do not count.
    """
    ctx = flyte.ctx()
    if ctx is None or not ctx.custom_context:
        return None

    # "wandb_" is also a prefix of the sweep keys ("wandb_sweep_*"). from_dict ignores
    # those keys, so they must not count here either. Otherwise a sweep-only context
    # would produce an all-defaults config instead of None.
    has_wandb_keys = any(
        key.startswith("wandb_") and not key.startswith("wandb_sweep_")
        for key in ctx.custom_context.keys()
    )
    if not has_wandb_keys:
        return None

    return _WandBConfig.from_dict(ctx.custom_context)
184
+
185
+
186
def wandb_config(
    project: Optional[str] = None,
    entity: Optional[str] = None,
    id: Optional[str] = None,
    name: Optional[str] = None,
    tags: Optional[list[str]] = None,
    config: Optional[dict[str, Any]] = None,
    mode: Optional[str] = None,
    group: Optional[str] = None,
    run_mode: RunMode = "auto",
    download_logs: bool = False,
    **kwargs: Any,
) -> _WandBConfig:
    """
    Build a wandb run configuration.

    The result can be used in two ways:
    1. Passed to `flyte.with_runcontext()` to set the wandb config for the whole run.
    2. Used as a context manager to override the config for specific tasks.

    Args:
        project: W&B project name
        entity: W&B entity (team or username)
        id: Unique run id (auto-generated if not provided)
        name: Human-readable run name
        tags: List of tags for organizing runs
        config: Dictionary of hyperparameters
        mode: "online", "offline" or "disabled"
        group: Group name for related runs
        run_mode: Flyte-specific run mode ("auto", "new" or "shared"). Controls
            whether tasks create new W&B runs or share existing ones
        download_logs: If `True`, downloads wandb run files after the task completes
            and shows them as a trace output in the Flyte UI
        **kwargs: Additional `wandb.init()` parameters; stored as None when empty
    """
    extra_init_params = kwargs or None
    return _WandBConfig(
        project=project,
        entity=entity,
        id=id,
        name=name,
        tags=tags,
        config=config,
        mode=mode,
        group=group,
        run_mode=run_mode,
        download_logs=download_logs,
        kwargs=extra_init_params,
    )
234
+
235
+
236
@dataclass
class _WandBSweepConfig:
    """W&B sweep settings carried through Flyte's ``custom_context``.

    ``to_sweep_config()`` produces the dict for ``wandb.sweep()``. The ``project``,
    ``entity``, ``prior_runs`` and ``download_logs`` fields are left out of that dict.

    Like the run config, an instance reads as a mapping of ``wandb_sweep_*`` string keys
    (via ``to_dict()``). Its mutating mapping methods write to the live Flyte
    custom_context. It also works as a context manager that resets those keys on exit.
    """

    # Core sweep definition.
    name: Optional[str] = None
    method: Optional[str] = None
    metric: Optional[dict[str, Any]] = None
    parameters: Optional[dict[str, Any]] = None

    # Sweep metadata (excluded from the wandb.sweep() config dict).
    project: Optional[str] = None
    entity: Optional[str] = None
    prior_runs: Optional[list[str]] = None

    # Flyte-specific: download wandb sweep logs after the task completes.
    download_logs: bool = False

    # Extra sweep config keys merged into the wandb.sweep() dict
    # (e.g. early_terminate, description, command, controller).
    kwargs: Optional[dict[str, Any]] = None

    def to_sweep_config(self) -> dict[str, Any]:
        """Build the ``wandb.sweep()``-compatible dict.

        Metadata fields are dropped, ``kwargs`` entries are merged in (overriding
        same-named fields), and every None-valued entry is removed.
        """
        raw = asdict(self)
        extra = raw.pop("kwargs", None)
        for non_sweep_key in ("project", "entity", "prior_runs", "download_logs"):
            raw.pop(non_sweep_key, None)
        if extra:
            raw.update(extra)
        return {key: val for key, val in raw.items() if val is not None}

    # Serialization.

    def to_dict(self) -> dict[str, str]:
        """Flatten into ``wandb_sweep_``-prefixed string entries for Flyte's custom_context."""
        return _to_dict_helper(self, "wandb_sweep")

    @classmethod
    def from_dict(cls, d: dict[str, str]) -> "_WandBSweepConfig":
        """Rebuild a sweep config from the ``wandb_sweep_*`` entries of a custom_context mapping."""
        return _from_dict_helper(cls, d, "wandb_sweep")

    # Mapping protocol. keys() + __getitem__ are what ``**config`` unpacking needs.
    # Reads come from to_dict(); writes go to the live Flyte custom_context.

    @staticmethod
    def _live_custom_context():
        """Return the current Flyte custom_context if it exists and is non-empty, else None."""
        current = flyte.ctx()
        if current and current.custom_context:
            return current.custom_context
        return None

    def keys(self):
        return self.to_dict().keys()

    def items(self):
        return self.to_dict().items()

    def __getitem__(self, key):
        flat = self.to_dict()
        return flat[key]

    def get(self, key, default=None):
        flat = self.to_dict()
        return flat[key] if key in flat else default

    def __setitem__(self, key, value):
        store = self._live_custom_context()
        if store is not None:
            store[key] = value

    def __delitem__(self, key):
        store = self._live_custom_context()
        if store is not None:
            del store[key]

    def pop(self, key, default=None):
        store = self._live_custom_context()
        return default if store is None else store.pop(key, default)

    def update(self, *args, **kwargs):
        store = self._live_custom_context()
        if store is not None:
            store.update(*args, **kwargs)

    # Context manager.

    def __enter__(self):
        saved, manager = _context_manager_enter(self, "wandb_sweep")
        self._saved_config = saved
        self._ctx = manager
        return self

    def __exit__(self, *args):
        _context_manager_exit(self._ctx, self._saved_config, "wandb_sweep", *args)
328
+
329
+
330
def get_wandb_sweep_context() -> Optional[_WandBSweepConfig]:
    """Return the W&B sweep config stored in the current Flyte context.

    Returns:
        A ``_WandBSweepConfig`` rebuilt from the context's ``wandb_sweep_*`` keys, or
        ``None`` in any of these cases:
        - there is no Flyte context;
        - its custom_context is empty;
        - it has no such keys.
    """
    ctx = flyte.ctx()
    if ctx is not None and ctx.custom_context:
        if any(key.startswith("wandb_sweep_") for key in ctx.custom_context.keys()):
            return _WandBSweepConfig.from_dict(ctx.custom_context)
    return None
341
+
342
+
343
def wandb_sweep_config(
    method: Optional[str] = None,
    metric: Optional[dict[str, Any]] = None,
    parameters: Optional[dict[str, Any]] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    prior_runs: Optional[list[str]] = None,
    name: Optional[str] = None,
    download_logs: bool = False,
    **kwargs: Any,
) -> _WandBSweepConfig:
    """
    Build a wandb sweep configuration for hyperparameter optimization.

    Args:
        method: Sweep method (e.g., "random", "grid", "bayes")
        metric: Metric to optimize (e.g., {"name": "loss", "goal": "minimize"})
        parameters: Parameter definitions for the sweep
        project: W&B project for the sweep
        entity: W&B entity for the sweep
        prior_runs: List of prior run IDs to include in the sweep analysis
        name: Sweep name (auto-generated as `{run_name}-{action_name}` if not provided)
        download_logs: If `True`, downloads all sweep run files after the task completes
            and shows them as a trace output in the Flyte UI
        **kwargs: Additional sweep config keys such as `early_terminate`, `description`,
            `command`, etc.; stored as None when empty

    See: https://docs.wandb.ai/models/sweeps/sweep-config-keys
    """
    extra_sweep_keys = kwargs or None
    return _WandBSweepConfig(
        name=name,
        method=method,
        metric=metric,
        parameters=parameters,
        project=project,
        entity=entity,
        prior_runs=prior_runs,
        download_logs=download_logs,
        kwargs=extra_sweep_keys,
    )