flyteplugins-wandb 2.0.0b52__py3-none-any.whl → 2.0.0b54__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -6,6 +6,7 @@
6
6
  - Parent/child task support with automatic run reuse
7
7
  - W&B sweep creation and management with `@wandb_sweep` decorator
8
8
  - Configuration management with `wandb_config()` and `wandb_sweep_config()`
9
+ - Distributed training support (auto-detects PyTorch DDP/torchrun)
9
10
 
10
11
  ## Basic usage:
11
12
 
@@ -119,6 +120,64 @@
119
120
  ).run(run_parallel_sweep, num_agents=2, trials_per_agent=5)
120
121
  ```
121
122
 
123
+ 6. Distributed Training Support:
124
+
125
+ The plugin auto-detects distributed training from environment variables
126
+ (RANK, WORLD_SIZE, LOCAL_RANK, etc.) set by torchrun/torch.distributed.elastic.
127
+
128
+ By default (`run_mode="auto"`):
129
+ - Single-node: Only rank 0 logs (1 run)
130
+ - Multi-node: Local rank 0 of each worker logs (1 run per worker)
131
+
132
+ ```python
133
+ from flyteplugins.pytorch.task import Elastic
134
+ from flyteplugins.wandb import wandb_init, get_wandb_run
135
+
136
+ torch_env = flyte.TaskEnvironment(
137
+ name="torch_env",
138
+ resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "5Gi"), gpu="V100:4"),
139
+ plugin_config=Elastic(nnodes=2, nproc_per_node=2),
140
+ )
141
+
142
+ @wandb_init
143
+ @torch_env.task
144
+ async def train_distributed():
145
+ torch.distributed.init_process_group("nccl")
146
+
147
+ # Only local rank 0 gets a W&B run; other ranks get None
148
+ run = get_wandb_run()
149
+ if run:
150
+ run.log({"loss": loss})
151
+
152
+ return run.id if run else "non-primary-rank"
153
+ ```
154
+
155
+ Use `run_mode="shared"` for all ranks to log to a single shared run:
156
+
157
+ ```python
158
+ @wandb_init(run_mode="shared")
159
+ @torch_env.task
160
+ async def train_distributed_shared():
161
+ # All ranks log to the same W&B run (with x_label to identify each rank)
162
+ run = get_wandb_run()
163
+ run.log({"rank_metric": value})
164
+ return run.id
165
+ ```
166
+
167
+ Use `run_mode="new"` for each rank to have its own W&B run:
168
+
169
+ ```python
170
+ @wandb_init(run_mode="new")
171
+ @torch_env.task
172
+ async def train_distributed_separate_runs():
173
+ # Each rank gets its own W&B run (grouped in W&B UI)
174
+ # Run IDs: {run_name}-{action_name}-rank-{rank} (single-node)
175
+ # Run IDs: {run_name}-{action_name}-worker-{worker_index}-rank-{local_rank} (multi-node)
176
+ run = get_wandb_run()
177
+ run.log({"rank_metric": value})
178
+ return run.id
179
+ ```
180
+
122
181
  Decorator order: `@wandb_init` or `@wandb_sweep` must be the outermost decorator:
123
182
 
124
183
  ```python
@@ -145,7 +204,11 @@ from ._context import (
145
204
  wandb_config,
146
205
  wandb_sweep_config,
147
206
  )
148
- from ._decorator import wandb_init, wandb_sweep
207
+ from ._decorator import (
208
+ _get_distributed_info,
209
+ wandb_init,
210
+ wandb_sweep,
211
+ )
149
212
  from ._link import Wandb, WandbSweep
150
213
 
151
214
  logger = logging.getLogger(__name__)
@@ -158,6 +221,7 @@ __all__ = [
158
221
  "download_wandb_run_logs",
159
222
  "download_wandb_sweep_dirs",
160
223
  "download_wandb_sweep_logs",
224
+ "get_distributed_info",
161
225
  "get_wandb_context",
162
226
  "get_wandb_run",
163
227
  "get_wandb_run_dir",
@@ -183,11 +247,15 @@ def get_wandb_run():
183
247
  Returns:
184
248
  `wandb.sdk.wandb_run.Run` | `None`: The current wandb run object or None.
185
249
  """
250
+ # First check Flyte context
186
251
  ctx = flyte.ctx()
187
- if not ctx or not ctx.data:
188
- return None
252
+ if ctx and ctx.data:
253
+ run = ctx.data.get("_wandb_run")
254
+ if run:
255
+ return run
189
256
 
190
- return ctx.data.get("_wandb_run")
257
+ # Fallback to wandb's global run
258
+ return wandb.run
191
259
 
192
260
 
193
261
  def get_wandb_sweep_id() -> str | None:
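Because `get_wandb_run()` now falls back to `wandb.run`, sweep objective callbacks (which run inside `wandb.agent()` without a Flyte context) can use the same accessor as tasks. A minimal sketch, assuming a sweep wired up as in the README examples; the metric value is a placeholder:

```python
from flyteplugins.wandb import get_wandb_run, wandb_init

@wandb_init
def objective():
    # No Flyte context exists inside wandb.agent() callbacks, so this resolves
    # through the new wandb.run fallback rather than ctx.data.
    run = get_wandb_run()
    run.log({"loss": 0.123})  # placeholder metric
```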
@@ -224,6 +292,25 @@ def get_wandb_run_dir() -> Optional[str]:
224
292
  return run.dir
225
293
 
226
294
 
295
+ def get_distributed_info() -> dict | None:
296
+ """
297
+ Get distributed training info if running in a distributed context.
298
+
299
+ This function auto-detects distributed training from environment variables
300
+ set by torchrun/torch.distributed.elastic.
301
+
302
+ Returns:
303
+ dict | None: Dictionary with distributed info or None if not distributed.
304
+ - rank: Global rank (0 to world_size-1)
305
+ - local_rank: Rank within the node (0 to local_world_size-1)
306
+ - world_size: Total number of processes
307
+ - local_world_size: Processes per node
308
+ - worker_index: Node/worker index (0 to num_workers-1)
309
+ - num_workers: Total number of nodes/workers
310
+ """
311
+ return _get_distributed_info()
312
+
313
+
227
314
  def download_wandb_run_dir(
228
315
  run_id: Optional[str] = None,
229
316
  path: Optional[str] = None,
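A short usage sketch for the new public `get_distributed_info()` wrapper. The environment name, resources, and metric are illustrative; the dict keys are the ones listed in the docstring above:

```python
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_distributed_info, get_wandb_run, wandb_init

env = flyte.TaskEnvironment(
    name="dist_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "5Gi"), gpu="V100:4"),
    plugin_config=Elastic(nnodes=2, nproc_per_node=4),
)

@wandb_init
@env.task
async def train() -> str:
    info = get_distributed_info()
    if info is None:
        return "not-distributed"  # plain single-process execution
    # For global rank 5 of this 2x4 job the dict would look like:
    # {"rank": 5, "local_rank": 1, "world_size": 8,
    #  "local_world_size": 4, "worker_index": 1, "num_workers": 2}
    run = get_wandb_run()  # None on ranks that skip W&B under run_mode="auto"
    if run:
        run.log({"num_workers": info["num_workers"]})
    return f"rank-{info['rank']}"
```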
@@ -213,7 +213,12 @@ def wandb_config(
213
213
  mode: "online", "offline" or "disabled"
214
214
  group: Group name for related runs
215
215
  run_mode: Flyte-specific run mode - "auto", "new" or "shared".
216
- Controls whether tasks create new W&B runs or share existing ones
216
+ Controls whether tasks create new W&B runs or share existing ones.
217
+ In distributed training context:
218
+ - "auto" (default): Single-node: only rank 0 logs.
219
+ Multi-node: local rank 0 of each worker logs (1 run per worker).
220
+ - "shared": All ranks log to a single shared W&B run.
221
+ - "new": Each rank gets its own W&B run (grouped in W&B UI).
217
222
  download_logs: If `True`, downloads wandb run files after task completes
218
223
  and shows them as a trace output in the Flyte UI
219
224
  **kwargs: Additional `wandb.init()` parameters
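Since `run_mode` is part of `wandb_config()`, the distributed logging behavior can also be chosen at launch time rather than on the decorator. A hedged sketch, reusing the `train_distributed` task shown earlier in this diff; project and entity names are placeholders:

```python
import flyte
from flyteplugins.wandb import wandb_config

# All ranks log into a single shared W&B run (one per worker on multi-node jobs).
run = flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        run_mode="shared",
    )
).run(train_distributed)
```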
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import logging
3
+ import os
3
4
  from contextlib import contextmanager
4
5
  from dataclasses import asdict
5
6
  from inspect import iscoroutinefunction
@@ -18,6 +19,144 @@ logger = logging.getLogger(__name__)
18
19
  F = TypeVar("F", bound=Callable[..., Any])
19
20
 
20
21
 
22
+ def _get_distributed_info() -> dict | None:
23
+ """
24
+ Auto-detect distributed training info from environment variables.
25
+
26
+ Returns None if not in a distributed training context.
27
+ Environment variables are set by torchrun/torch.distributed.elastic.
28
+ """
29
+ if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
30
+ return None
31
+
32
+ world_size = int(os.environ["WORLD_SIZE"])
33
+ if world_size <= 1:
34
+ return None
35
+
36
+ local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", world_size))
37
+
38
+ return {
39
+ "rank": int(os.environ["RANK"]),
40
+ "local_rank": int(os.environ.get("LOCAL_RANK", "0")),
41
+ "world_size": world_size,
42
+ "local_world_size": local_world_size,
43
+ "worker_index": int(os.environ.get("GROUP_RANK", "0")),
44
+ "num_workers": world_size // local_world_size if local_world_size > 0 else 1,
45
+ }
46
+
47
+
48
+ def _is_multi_node(info: dict) -> bool:
49
+ """Check if this is a multi-node distributed setup."""
50
+ return info["num_workers"] > 1
51
+
52
+
53
+ def _is_primary_rank(info: dict) -> bool:
54
+ """Check if current process is rank 0 (primary)."""
55
+ return info["rank"] == 0
56
+
57
+
58
+ def _should_skip_rank(run_mode: RunMode, dist_info: dict) -> bool:
59
+ """
60
+ Check if this rank should skip wandb initialization.
61
+
62
+ For run_mode="auto":
63
+ - Single-node: Only rank 0 initializes wandb
64
+ - Multi-node: Only local rank 0 of each worker initializes wandb
65
+
66
+ For run_mode="shared" or "new": All ranks initialize wandb.
67
+ """
68
+ if run_mode != "auto":
69
+ return False
70
+
71
+ is_multi_node = _is_multi_node(dist_info)
72
+ is_primary = _is_primary_rank(dist_info)
73
+ is_local_primary = dist_info["local_rank"] == 0
74
+
75
+ if is_multi_node:
76
+ # Multi-node: only local rank 0 of each node logs
77
+ return not is_local_primary
78
+ else:
79
+ # Single-node: only rank 0 logs
80
+ return not is_primary
81
+
82
+
83
+ def _configure_distributed_run(
84
+ init_kwargs: dict,
85
+ run_mode: RunMode,
86
+ dist_info: dict,
87
+ base_run_id: str,
88
+ ) -> dict:
89
+ """
90
+ Configure wandb.init() kwargs for distributed training.
91
+
92
+ Sets run ID, group, and shared mode settings based on:
93
+ - run_mode: "auto", "new", or "shared"
94
+ - dist_info: distributed topology (rank, worker_index, etc.)
95
+ - base_run_id: base string for generating run IDs
96
+
97
+ Run ID patterns:
98
+ - Single-node auto/shared: {base_run_id}
99
+ - Single-node new: {base_run_id}-rank-{rank}
100
+ - Multi-node auto/shared: {base_run_id}-worker-{worker_index}
101
+ - Multi-node new: {base_run_id}-worker-{worker_index}-rank-{local_rank}
102
+ """
103
+ is_multi_node = _is_multi_node(dist_info)
104
+ is_primary = _is_primary_rank(dist_info)
105
+
106
+ # Build run ID based on mode and topology
107
+ if "id" not in init_kwargs or init_kwargs["id"] is None:
108
+ if run_mode == "new":
109
+ # Each rank gets its own run
110
+ if is_multi_node:
111
+ init_kwargs["id"] = f"{base_run_id}-worker-{dist_info['worker_index']}-rank-{dist_info['local_rank']}"
112
+ else:
113
+ init_kwargs["id"] = f"{base_run_id}-rank-{dist_info['rank']}"
114
+ else: # run_mode == "auto" or "shared"
115
+ if is_multi_node:
116
+ init_kwargs["id"] = f"{base_run_id}-worker-{dist_info['worker_index']}"
117
+ else:
118
+ init_kwargs["id"] = base_run_id
119
+
120
+ # Set group for multiple runs (run_mode="new")
121
+ if run_mode == "new" and "group" not in init_kwargs:
122
+ if is_multi_node:
123
+ init_kwargs["group"] = f"{base_run_id}-worker-{dist_info['worker_index']}"
124
+ else:
125
+ init_kwargs["group"] = base_run_id
126
+
127
+ # Configure W&B shared mode for run_mode="shared"
128
+ if run_mode == "shared":
129
+ if is_multi_node:
130
+ x_label = f"worker-{dist_info['worker_index']}-rank-{dist_info['local_rank']}"
131
+ # For multi-node, primary is local_rank 0 within each worker
132
+ is_worker_primary = dist_info["local_rank"] == 0
133
+ else:
134
+ x_label = f"rank-{dist_info['rank']}"
135
+ # For single-node, primary is rank 0
136
+ is_worker_primary = is_primary
137
+
138
+ existing_settings = init_kwargs.get("settings")
139
+ shared_config = {
140
+ "mode": "shared",
141
+ "x_primary": is_worker_primary,
142
+ "x_label": x_label,
143
+ "x_update_finish_state": is_worker_primary,
144
+ }
145
+
146
+ # Handle both dict and wandb.Settings objects
147
+ if existing_settings is None:
148
+ init_kwargs["settings"] = wandb.Settings(**shared_config)
149
+ elif isinstance(existing_settings, dict):
150
+ init_kwargs["settings"] = wandb.Settings(**{**existing_settings, **shared_config})
151
+ else:
152
+ # existing_settings is already a wandb.Settings object
153
+ for key, value in shared_config.items():
154
+ setattr(existing_settings, key, value)
155
+ init_kwargs["settings"] = existing_settings
156
+
157
+ return init_kwargs
158
+
159
+
21
160
  def _build_init_kwargs() -> dict[str, Any]:
22
161
  """Build wandb.init() kwargs from current context config."""
23
162
  context_config = get_wandb_context()
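To make the helpers above concrete, a small illustration that simulates the variables torchrun would export for one rank of a 2-node x 2-process job and then calls the private helpers directly. This is illustration only: the run/action names are placeholders and the plugin must be installed for the import to work.

```python
import os

# What torchrun would set for global rank 3 of a 2-node x 2-process job.
os.environ.update({
    "RANK": "3",
    "WORLD_SIZE": "4",
    "LOCAL_RANK": "1",
    "LOCAL_WORLD_SIZE": "2",
    "GROUP_RANK": "1",
})

from flyteplugins.wandb._decorator import _get_distributed_info, _should_skip_rank

info = _get_distributed_info()
# {"rank": 3, "local_rank": 1, "world_size": 4,
#  "local_world_size": 2, "worker_index": 1, "num_workers": 2}

base_run_id = "myrun-train"  # placeholder for {run_name}-{action_name}
print(_should_skip_rank("auto", info))  # True: only local rank 0 of each node logs
print(f"{base_run_id}-worker-{info['worker_index']}")  # run ID for auto/shared
print(f"{base_run_id}-worker-{info['worker_index']}-rank-{info['local_rank']}")  # run ID for "new"
```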
@@ -50,6 +189,7 @@ def _wandb_run(
50
189
  """
51
190
  # Try to get Flyte context
52
191
  ctx = flyte.ctx()
192
+ dist_info = _get_distributed_info()
53
193
 
54
194
  # This enables @wandb_init to work in wandb.agent() callbacks (sweep objectives)
55
195
  if func and ctx is None:
@@ -61,6 +201,12 @@ def _wandb_run(
61
201
  run.finish()
62
202
  return
63
203
  elif func and ctx:
204
+ # Check if there's already a W&B run from parent
205
+ existing_run = ctx.data.get("_wandb_run")
206
+ if existing_run:
207
+ yield existing_run
208
+ return
209
+
64
210
  raise RuntimeError(
65
211
  "@wandb_init cannot be applied to traces. Traces can access the parent's wandb run via get_wandb_run()."
66
212
  )
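A sketch of what this change enables: a helper decorated with `@wandb_init` for standalone use (for example a sweep objective) can now also be called from inside a `@wandb_init` task and reuses the parent's run instead of raising. Environment and metric names are illustrative:

```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_init

env = flyte.TaskEnvironment(name="wandb_reuse_demo")

@wandb_init
def log_step(step: int):
    # When called inside a @wandb_init task, the parent's run is found in
    # ctx.data and reused rather than triggering the RuntimeError above.
    run = get_wandb_run()
    run.log({"step": step})

@wandb_init
@env.task
async def train():
    log_step(1)
```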
@@ -85,46 +231,64 @@ def _wandb_run(
85
231
 
86
232
  # Get current action name for run ID generation
87
233
  current_action = ctx.action.name
234
+ base_run_id = f"{ctx.action.run_name}-{current_action}"
235
+
236
+ # Handle distributed training
237
+ if dist_info:
238
+ if _should_skip_rank(run_mode, dist_info):
239
+ yield None
240
+ return
241
+
242
+ init_kwargs = _configure_distributed_run(init_kwargs, run_mode, dist_info, base_run_id)
243
+ else:
244
+ # Non-distributed training
245
+ # Determine if we should reuse parent's run
246
+ should_reuse = False
247
+ if run_mode == "shared":
248
+ should_reuse = True
249
+ elif run_mode == "auto":
250
+ should_reuse = bool(saved_run_id)
251
+
252
+ # Determine run ID
253
+ if "id" not in init_kwargs or init_kwargs["id"] is None:
254
+ if should_reuse:
255
+ if not saved_run_id:
256
+ raise RuntimeError("Cannot reuse parent run: no parent run ID found")
257
+ init_kwargs["id"] = saved_run_id
258
+ else:
259
+ init_kwargs["id"] = base_run_id
88
260
 
89
- # Determine if we should reuse parent's run
90
- should_reuse = False
91
- if run_mode == "shared":
92
- should_reuse = True
93
- elif run_mode == "auto":
94
- should_reuse = bool(saved_run_id)
95
-
96
- # Determine run ID
97
- if "id" not in init_kwargs or init_kwargs["id"] is None:
98
- if should_reuse:
99
- if not saved_run_id:
100
- raise RuntimeError("Cannot reuse parent run: no parent run ID found")
101
- init_kwargs["id"] = saved_run_id
102
- else:
103
- init_kwargs["id"] = f"{ctx.action.run_name}-{current_action}"
104
-
105
- # Configure reinit parameter (only for local mode)
106
- # In remote/shared mode, wandb handles run creation/joining automatically
107
- if flyte.ctx().mode == "local":
108
- if should_reuse:
109
- if "reinit" not in init_kwargs:
110
- init_kwargs["reinit"] = "return_previous"
111
- else:
112
- init_kwargs["reinit"] = "create_new"
113
-
114
- # Configure remote mode settings
115
- if flyte.ctx().mode == "remote":
116
- is_primary = not should_reuse
117
- existing_settings = init_kwargs.get("settings", {})
118
-
119
- shared_config = {
120
- "mode": "shared",
121
- "x_primary": is_primary,
122
- "x_label": current_action,
123
- }
124
- if not is_primary:
125
- shared_config["x_update_finish_state"] = False
126
-
127
- init_kwargs["settings"] = wandb.Settings(**{**existing_settings, **shared_config})
261
+ # Configure reinit parameter (only for local mode)
262
+ if ctx.mode == "local":
263
+ if should_reuse:
264
+ if "reinit" not in init_kwargs:
265
+ init_kwargs["reinit"] = "return_previous"
266
+ else:
267
+ init_kwargs["reinit"] = "create_new"
268
+
269
+ # Configure remote mode settings
270
+ if ctx.mode == "remote":
271
+ is_primary = not should_reuse
272
+ existing_settings = init_kwargs.get("settings")
273
+
274
+ shared_config = {
275
+ "mode": "shared",
276
+ "x_primary": is_primary,
277
+ "x_label": current_action,
278
+ }
279
+ if not is_primary:
280
+ shared_config["x_update_finish_state"] = False
281
+
282
+ # Handle None, dict, and wandb.Settings objects
283
+ if existing_settings is None:
284
+ init_kwargs["settings"] = wandb.Settings(**shared_config)
285
+ elif isinstance(existing_settings, dict):
286
+ init_kwargs["settings"] = wandb.Settings(**{**existing_settings, **shared_config})
287
+ else:
288
+ # existing_settings is already a wandb.Settings object
289
+ for key, value in shared_config.items():
290
+ setattr(existing_settings, key, value)
291
+ init_kwargs["settings"] = existing_settings
128
292
 
129
293
  # Initialize wandb
130
294
  run = wandb.init(**init_kwargs)
@@ -141,18 +305,18 @@ def _wandb_run(
141
305
  # Determine if this is a primary run
142
306
  is_primary_run = run_mode == "new" or (run_mode == "auto" and saved_run_id is None)
143
307
 
308
+ # Determine if we should call finish()
309
+ should_finish = False
144
310
  if run:
145
- # Different cleanup logic for local vs remote mode
146
- should_finish = False
147
-
148
- if flyte.ctx().mode == "remote":
149
- # In remote/shared mode, always call run.finish() to flush data
150
- # For secondary tasks, x_update_finish_state=False prevents actually finishing
151
- # For primary tasks, this properly finishes the run
152
- should_finish = True
153
- elif is_primary_run:
154
- # In local mode, only primary tasks should call run.finish()
155
- # Secondary tasks reuse the parent's run object, so they must not finish it
311
+ if dist_info and run_mode == "shared":
312
+ # For distributed shared mode, only primary (local_rank 0) finishes
313
+ is_multi_node = _is_multi_node(dist_info)
314
+ if is_multi_node:
315
+ should_finish = dist_info["local_rank"] == 0
316
+ else:
317
+ should_finish = dist_info["rank"] == 0
318
+ elif ctx.mode == "remote" or is_primary_run:
319
+ # In remote mode or for primary runs, always finish
156
320
  should_finish = True
157
321
 
158
322
  if should_finish:
@@ -192,10 +356,14 @@ def wandb_init(
192
356
 
193
357
  Args:
194
358
  run_mode: Controls whether to create a new W&B run or share an existing one:
195
-
196
- 1. "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run
197
- 2. "new": Always creates a new wandb run with a unique ID
198
- 3. "shared": Always shares the parent's run ID (useful for child tasks)
359
+ - "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run
360
+ - "new": Always creates a new wandb run with a unique ID
361
+ - "shared": Always shares the parent's run ID (useful for child tasks)
362
+ In distributed training context:
363
+ - "auto" (default): Single-node: only rank 0 logs.
364
+ Multi-node: local rank 0 of each worker logs (1 run per worker).
365
+ - "shared": All ranks log to a single shared W&B run.
366
+ - "new": Each rank gets its own W&B run (grouped in W&B UI).
199
367
  download_logs: If `True`, downloads wandb run files after task completes
200
368
  and shows them as a trace output in the Flyte UI. If None, uses
201
369
  the value from `wandb_config()` context if set.
@@ -230,15 +398,59 @@ def wandb_init(
230
398
 
231
399
  # Check if it's a Flyte task (AsyncFunctionTaskTemplate)
232
400
  if isinstance(func, AsyncFunctionTaskTemplate):
233
- # Create a Wandb link
234
- # Even if run_mode="shared", we still add a link - it will point to the parent's run
235
- wandb_link = Wandb(project=project, entity=entity, run_mode=run_mode)
401
+ # Detect distributed config from plugin_config
402
+ nnodes = 1
403
+ nproc_per_node = 1
404
+ plugin_config = getattr(func, "plugin_config", None)
405
+
406
+ if plugin_config is not None and type(plugin_config).__name__ == "Elastic":
407
+ nnodes_val = getattr(plugin_config, "nnodes", 1)
408
+ if isinstance(nnodes_val, int):
409
+ nnodes = nnodes_val
410
+ elif isinstance(nnodes_val, str):
411
+ parts = nnodes_val.split(":")
412
+ nnodes = int(parts[-1]) if parts else 1
413
+
414
+ nproc_val = getattr(plugin_config, "nproc_per_node", 1)
415
+ if isinstance(nproc_val, int):
416
+ nproc_per_node = nproc_val
417
+ elif isinstance(nproc_val, str):
418
+ try:
419
+ nproc_per_node = int(nproc_val)
420
+ except ValueError:
421
+ nproc_per_node = 1
236
422
 
237
- # Get existing links from the task and add wandb link
423
+ is_distributed = nnodes > 1 or nproc_per_node > 1
424
+
425
+ # Add W&B links
426
+ wandb_id = kwargs.get("id")
238
427
  existing_links = getattr(func, "links", ())
239
428
 
240
- # Use override to properly add the link to the task
241
- func = func.override(links=(*existing_links, wandb_link))
429
+ if nnodes > 1:
430
+ # Multi-node: one link per worker
431
+ wandb_links = tuple(
432
+ Wandb(
433
+ project=project,
434
+ entity=entity,
435
+ run_mode=run_mode,
436
+ id=wandb_id,
437
+ _is_distributed=True,
438
+ _worker_index=i,
439
+ name=f"Weights & Biases Worker {i}",
440
+ )
441
+ for i in range(nnodes)
442
+ )
443
+ func = func.override(links=(*existing_links, *wandb_links))
444
+ else:
445
+ # Single-node (distributed or not): one link
446
+ wandb_link = Wandb(
447
+ project=project,
448
+ entity=entity,
449
+ run_mode=run_mode,
450
+ id=wandb_id,
451
+ _is_distributed=is_distributed,
452
+ )
453
+ func = func.override(links=(*existing_links, wandb_link))
242
454
 
243
455
  # Wrap the task's execute method with wandb_run
244
456
  original_execute = func.execute
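The `nnodes` field on `Elastic` may be an int or an elastic range string such as `"1:2"`, and the block above takes the last component as the node count. A standalone sketch of that parsing rule (the function name is ours, not part of the plugin):

```python
def parse_nnodes(nnodes_val) -> int:
    """Mirror the parsing above: ints pass through, "min:max" strings use the upper bound."""
    if isinstance(nnodes_val, int):
        return nnodes_val
    if isinstance(nnodes_val, str):
        parts = nnodes_val.split(":")
        return int(parts[-1]) if parts else 1
    return 1

assert parse_nnodes(2) == 2
assert parse_nnodes("4") == 4
assert parse_nnodes("1:2") == 2  # elastic range: the upper bound determines worker links
```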
@@ -15,11 +15,15 @@ class Wandb(Link):
15
15
  host: Base W&B host URL
16
16
  project: W&B project name (overrides context config if provided)
17
17
  entity: W&B entity/team name (overrides context config if provided)
18
- run_mode: Controls whether to create a new W&B run or share an existing one:
19
-
20
- 1. "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run
21
- 2. "new": Always creates a new wandb run with a unique ID
22
- 3. "shared": Always shares the parent's run ID (useful for child tasks)
18
+ run_mode: Determines the link behavior:
19
+ - "auto" (default): Use parent's run if available, otherwise create new
20
+ - "new": Always creates a new wandb run with a unique ID
21
+ - "shared": Always shares the parent's run ID (useful for child tasks)
22
+ In distributed training context:
23
+ - "auto" (default): Single-node: only rank 0 logs
24
+ Multi-node: only local rank 0 of each worker logs
25
+ - "shared": Link to a single shared W&B run.
26
+ - "new": Link to group view.
23
27
  id: Optional W&B run ID (overrides context config if provided)
24
28
  name: Link name in the Flyte UI
25
29
  """
@@ -30,6 +34,10 @@ class Wandb(Link):
30
34
  run_mode: RunMode = "auto"
31
35
  id: Optional[str] = None
32
36
  name: str = "Weights & Biases"
37
+ # Internal: set by @wandb_init for distributed training tasks
38
+ _is_distributed: bool = False
39
+ # Internal: worker index for multi-node distributed training (set by @wandb_init)
40
+ _worker_index: Optional[int] = None
33
41
 
34
42
  def get_link(
35
43
  self,
@@ -69,6 +77,35 @@ class Wandb(Link):
69
77
  if not wandb_project or not wandb_entity:
70
78
  return self.host
71
79
 
80
+ # Distributed training links - derived from decorator-time info (plugin_config)
81
+ # _is_distributed and _worker_index are set by @wandb_init based on Elastic config
82
+ is_multi_node = self._worker_index is not None
83
+
84
+ if self._is_distributed:
85
+ base_id = user_provided_id or f"{run_name}-{action_name}"
86
+
87
+ # For run_mode="new", link to group view
88
+ if run_mode == "new":
89
+ if is_multi_node:
90
+ # Multi-node: link to per-worker group
91
+ group_name = f"{base_id}-worker-{self._worker_index}"
92
+ else:
93
+ # Single-node: link to single group
94
+ group_name = base_id
95
+
96
+ return f"{self.host}/{wandb_entity}/{wandb_project}/groups/{group_name}"
97
+
98
+ # For run_mode="auto" or "shared", link to run directly
99
+ if is_multi_node:
100
+ # Multi-node: link to worker-specific run
101
+ wandb_run_id = f"{base_id}-worker-{self._worker_index}"
102
+ else:
103
+ # Single-node: link to single run
104
+ wandb_run_id = base_id
105
+
106
+ return f"{self.host}/{wandb_entity}/{wandb_project}/runs/{wandb_run_id}"
107
+
108
+ # Non-distributed: link to specific run
72
109
  # Determine run ID based on run_mode setting
73
110
  if run_mode == "new":
74
111
  # Always create new run - use user-provided ID if available, otherwise generate
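For reference, a sketch of the URLs this logic yields for worker 1 of a multi-node task; the host, entity, project, and base ID below are placeholders:

```python
host, entity, project = "https://wandb.ai", "my-team", "my-project"
base_id = "myrun-train"  # {run_name}-{action_name}
worker_index = 1

# run_mode="new": link to the per-worker group view
print(f"{host}/{entity}/{project}/groups/{base_id}-worker-{worker_index}")
# run_mode="auto" or "shared": link to the worker-specific run page
print(f"{host}/{entity}/{project}/runs/{base_id}-worker-{worker_index}")
```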
@@ -0,0 +1,266 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-wandb
3
+ Version: 2.0.0b54
4
+ Summary: Weights & Biases plugin for Flyte
5
+ Author: Flyte Contributors
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: wandb
9
+ Requires-Dist: flyte
10
+
11
+ # Weights & Biases Plugin
12
+
13
+ This plugin provides integration between Flyte and Weights & Biases (W&B) for experiment tracking, including support for distributed training with PyTorch Elastic.
14
+
15
+ ## Quickstart
16
+
17
+ ```python
18
+ from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run
19
+
20
+ @wandb_init(project="my-project", entity="my-team")
21
+ @env.task
22
+ def train():
23
+ run = get_wandb_run()
24
+ run.log({"loss": 0.5, "accuracy": 0.9})
25
+ ```
26
+
27
+ ## Core concepts
28
+
29
+ ### Decorator order
30
+
31
+ `@wandb_init` and `@wandb_sweep` must be the **outermost decorators** (applied after `@env.task`):
32
+
33
+ ```python
34
+ @wandb_init # Outermost
35
+ @env.task # Task decorator
36
+ def my_task():
37
+ ...
38
+ ```
39
+
40
+ ### Run modes
41
+
42
+ The `run_mode` parameter controls how W&B runs are created:
43
+
44
+ - **`"auto"`** (default): Creates a new run if no parent exists, otherwise shares the parent's run
45
+ - **`"new"`**: Always creates a new W&B run with a unique ID
46
+ - **`"shared"`**: Always shares the parent's run ID (useful for child tasks)
47
+
48
+ ### Accessing the run
49
+
50
+ Use `get_wandb_run()` to access the current W&B run:
51
+
52
+ ```python
53
+ from flyteplugins.wandb import get_wandb_run
54
+
55
+ run = get_wandb_run()
56
+ if run:
57
+ run.log({"metric": value})
58
+ ```
59
+
60
+ Returns `None` if not within a `@wandb_init` decorated task or if the current rank should not log (in distributed training).
61
+
62
+ ## Distributed training
63
+
64
+ The plugin automatically detects distributed training environments (PyTorch Elastic) and configures W&B appropriately.
65
+
66
+ ### Environment variables
67
+
68
+ Distributed training is detected via these environment variables (set by `torchrun`/`torch.distributed.elastic`):
69
+
70
+ | Variable | Description |
71
+ |----------|-------------|
72
+ | `RANK` | Global rank of the process |
73
+ | `WORLD_SIZE` | Total number of processes |
74
+ | `LOCAL_RANK` | Rank within the current node |
75
+ | `LOCAL_WORLD_SIZE` | Number of processes per node |
76
+ | `GROUP_RANK` | Worker/node index (0, 1, 2, ...) |
77
+
78
+ ### Run modes in distributed context
79
+
80
+ | Mode | Single-Node | Multi-Node |
81
+ |------|-------------|------------|
82
+ | `"auto"` | Only rank 0 logs → 1 run | Local rank 0 of each worker logs → N runs (1 per worker) |
83
+ | `"shared"` | All ranks log to 1 shared run | All ranks per worker log to shared run → N runs (1 per worker) |
84
+ | `"new"` | Each rank gets its own run (grouped) → N runs | Each rank gets its own run (grouped per worker) → N×GPUs runs |
85
+
86
+ ### Run ID patterns
87
+
88
+ | Scenario | Run ID Pattern |
89
+ |----------|----------------|
90
+ | Single-node auto/shared | `{run_name}-{action_name}` |
91
+ | Single-node new | `{run_name}-{action_name}-rank-{rank}` |
92
+ | Multi-node auto/shared | `{run_name}-{action_name}-worker-{worker_index}` |
93
+ | Multi-node new | `{run_name}-{action_name}-worker-{worker_index}-rank-{local_rank}` |
94
+
95
+ ### Example: Distributed training task
96
+
97
+ ```python
98
+ from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run, get_distributed_info
99
+ from flyteplugins.pytorch.task import Elastic
100
+
101
+ # Multi-node environment (2 nodes, 4 GPUs each)
102
+ multi_node_env = flyte.TaskEnvironment(
103
+ name="multi_node_env",
104
+ resources=flyte.Resources(gpu="V100:4", shm="auto"),
105
+ plugin_config=Elastic(nproc_per_node=4, nnodes=2),
106
+ secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
107
+ )
108
+
109
+ @wandb_init # run_mode="auto" by default
110
+ @multi_node_env.task
111
+ def train_multi_node():
112
+ import torch.distributed as dist
113
+ dist.init_process_group("nccl")
114
+
115
+ run = get_wandb_run() # Returns run for local_rank 0, None for others
116
+ dist_info = get_distributed_info()
117
+
118
+ # Training loop...
119
+ if run:
120
+ run.log({"loss": loss.item()})
121
+
122
+ dist.destroy_process_group()
123
+ ```
124
+
125
+ ### Shared mode for all-rank logging
126
+
127
+ Use `run_mode="shared"` when you want all ranks to log to the same W&B run:
128
+
129
+ ```python
130
+ @wandb_init(run_mode="shared")
131
+ @multi_node_env.task
132
+ def train_all_ranks_log():
133
+ run = get_wandb_run() # All ranks get a run object
134
+
135
+ # All ranks can log - W&B handles deduplication
136
+ run.log({"loss": loss.item(), "rank": dist.get_rank()})
137
+ ```
138
+
139
+ ### New mode for per-rank runs
140
+
141
+ Use `run_mode="new"` when you want each rank to have its own W&B run:
142
+
143
+ ```python
144
+ @wandb_init(run_mode="new")
145
+ @multi_node_env.task
146
+ def train_per_rank():
147
+ run = get_wandb_run() # Each rank gets its own run
148
+
149
+ # Runs are grouped in W&B UI for easy comparison
150
+ run.log({"loss": loss.item()})
151
+ ```
152
+
153
+ ## Configuration
154
+
155
+ ### wandb_config
156
+
157
+ Use `wandb_config()` to pass configuration that propagates to child tasks:
158
+
159
+ ```python
160
+ from flyteplugins.wandb import wandb_config
161
+
162
+ # With flyte.with_runcontext
163
+ run = flyte.with_runcontext(
164
+ custom_context=wandb_config(
165
+ project="my-project",
166
+ entity="my-team",
167
+ tags=["experiment-1"],
168
+ )
169
+ ).run(my_task)
170
+
171
+ # As a context manager
172
+ with wandb_config(project="override-project"):
173
+ await child_task()
174
+ ```
175
+
176
+ ### Decorator vs context config
177
+
178
+ - **Decorator arguments** (`@wandb_init(project=...)`) are available only within the current task and its traces
179
+ - **Context config** (`wandb_config(...)`) propagates to child tasks
180
+
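
A small sketch contrasting the two scopes; the environment, task, and project names are placeholders:

```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

env = flyte.TaskEnvironment(name="scope_demo")

@wandb_init
@env.task
async def child() -> str | None:
    run = get_wandb_run()
    return run.project if run else None

@wandb_init(project="parent-project")  # decorator arg: visible only inside parent
@env.task
async def parent() -> str | None:
    # wandb_config() propagates to everything awaited inside the block,
    # so the child logs to "child-project" rather than inheriting the decorator arg.
    with wandb_config(project="child-project"):
        return await child()
```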
181
+ ## W&B links
182
+
183
+ Tasks decorated with `@wandb_init` or `@wandb_sweep` automatically get W&B links in the Flyte UI:
184
+
185
+ - For distributed training with multiple workers, each worker gets its own link
186
+ - Links point directly to the corresponding W&B runs or sweeps
187
+ - Project/entity are retrieved from decorator parameters or context configuration
188
+
189
+ ## Sweeps
190
+
191
+ Use `@wandb_sweep` to create W&B sweeps:
192
+
193
+ ```python
194
+ from flyteplugins.wandb import wandb_sweep, wandb_sweep_config, get_wandb_sweep_id
195
+
196
+ @wandb_init
197
+ def objective():
198
+ # Training logic - this runs for each sweep trial
199
+ run = get_wandb_run()
200
+ config = run.config # Sweep parameters are passed via run.config
201
+
202
+ # Train with sweep-suggested hyperparameters
203
+ model = train(lr=config.lr, batch_size=config.batch_size)
204
+ wandb.log({"loss": loss, "accuracy": accuracy})
205
+
206
+ @wandb_sweep
207
+ @env.task
208
+ def run_sweep():
209
+ sweep_id = get_wandb_sweep_id()
210
+
211
+ # Launch sweep agents to run trials
212
+ # count=10 means run 10 trials total
213
+ wandb.agent(sweep_id, function=objective, count=10)
214
+ ```
215
+
216
+ **Note:** A maximum of **20 sweep agents** can be launched at a time.
217
+
218
+ Configure sweeps with `wandb_sweep_config()`:
219
+
220
+ ```python
221
+ run = flyte.with_runcontext(
222
+ custom_context=wandb_sweep_config(
223
+ method="bayes",
224
+ metric={"name": "loss", "goal": "minimize"},
225
+ parameters={"lr": {"min": 1e-5, "max": 1e-2}},
226
+ project="my-project",
227
+ )
228
+ ).run(run_sweep)
229
+ ```
230
+
231
+ ## Downloading logs
232
+
233
+ Set `download_logs=True` to download W&B run/sweep logs after task completion. The download I/O is traced by Flyte's `@flyte.trace`, making the logs visible in the Flyte UI:
234
+
235
+ ```python
236
+ @wandb_init(download_logs=True)
237
+ @env.task
238
+ def train():
239
+ ...
240
+
241
+ # Or via context
242
+ wandb_config(download_logs=True)
243
+ wandb_sweep_config(download_logs=True)
244
+ ```
245
+
246
+ The downloaded logs include all files uploaded to W&B during the run (metrics, artifacts, etc.).
247
+
248
+ ## API reference
249
+
250
+ ### Functions
251
+
252
+ - `get_wandb_run()` - Get the current W&B run object (or `None`)
253
+ - `get_wandb_sweep_id()` - Get the current sweep ID (or `None`)
254
+ - `get_distributed_info()` - Get distributed training info dict (or `None`)
255
+ - `wandb_config(...)` - Create W&B configuration for context
256
+ - `wandb_sweep_config(...)` - Create sweep configuration for context
257
+
258
+ ### Decorators
259
+
260
+ - `@wandb_init` - Initialize W&B for a task or function
261
+ - `@wandb_sweep` - Create a W&B sweep for a task
262
+
263
+ ### Links
264
+
265
+ - `Wandb` - Link class for W&B runs
266
+ - `WandbSweep` - Link class for W&B sweeps
@@ -0,0 +1,8 @@
1
+ flyteplugins/wandb/__init__.py,sha256=kfCwLEFT2foHYaPrqtpOMnV_qwVM7qw9IPFC2lJUhBM,18704
2
+ flyteplugins/wandb/_context.py,sha256=7-MnHpoh4OdjKWI6SWYljLqa2cS-wDA6PmDVGudyDHY,13073
3
+ flyteplugins/wandb/_decorator.py,sha256=ge3ZT3AxFNDWtOCzhoRvJsM2WPjzQmX4SV3rhu1vHdc,22749
4
+ flyteplugins/wandb/_link.py,sha256=dQfH9BoI0eMwm-hf_rh9aFZ5bgYZWDFCQ6jU5TgA214,7027
5
+ flyteplugins_wandb-2.0.0b54.dist-info/METADATA,sha256=RRYiuP4fX4giv7pNQhsVzinkwDJ_tvEFG0lDxhld23M,7710
6
+ flyteplugins_wandb-2.0.0b54.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
7
+ flyteplugins_wandb-2.0.0b54.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
8
+ flyteplugins_wandb-2.0.0b54.dist-info/RECORD,,
@@ -1,34 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: flyteplugins-wandb
3
- Version: 2.0.0b52
4
- Summary: Weights & Biases plugin for Flyte
5
- Author: Flyte Contributors
6
- Requires-Python: >=3.10
7
- Description-Content-Type: text/markdown
8
- Requires-Dist: wandb
9
- Requires-Dist: flyte
10
-
11
- # Weights & Biases Plugin
12
-
13
- - Tasks decorated with `@wandb_init` or `@wandb_sweep` automatically get W&B links in the Flyte UI that point directly to the corresponding W&B runs or sweeps. Links retrieve project/entity from decorator parameters or context configuration (from `with_runcontext`).
14
- - `@wandb_init` and `@wandb_sweep` must be the **outermost decorators** (applied after `@env.task`). For example:
15
-
16
- ```python
17
- @wandb_init
18
- @env.task
19
- def my_task():
20
- ...
21
- ```
22
-
23
- - By default (`run_mode="auto"`), child tasks automatically reuse their parent's W&B run if one exists, or create a new run if they're top-level tasks. You can override this with `run_mode="new"` (always create new) or `run_mode="shared"` (always reuse parent).
24
- - `@wandb_init` should be applied to tasks (not traces). Traces can access the parent task's W&B run via `get_wandb_run()`. `@wandb_init` can also be applied to regular Python functions for use in `wandb.agent()` sweep callbacks.
25
- - The wandb run can be accessed via `get_wandb_run()`, which returns the run object or `None` if not within a `@wandb_init` decorated task.
26
- - When using `run_mode="shared"` or `run_mode="auto"` (with a parent run), child tasks reuse the parent's run ID. Configuration from `wandb_config()` is merged with decorator parameters.
27
- - `wandb_config` can be used to pass configuration to tasks enclosed within the context manager and can also be provided via `with_runcontext`.
28
- - When the context manager exits, the configuration falls back to the parent task's config.
29
- - Arguments passed to `wandb_init` decorator are available only within the current task and traces and are not propagated to child tasks (use `wandb_config` for child tasks).
30
- - At most 20 sweep agents can be launched at a time: https://docs.wandb.ai/models/sweeps/existing-project#3-launch-agents.
31
- - `@wandb_sweep` creates a W&B sweep and adds a sweep link to the decorated task. The sweep ID is available via `get_wandb_sweep_id()`. For the parent task that creates the sweep, the link points to the project's sweeps list page. For child tasks, the link points to the specific sweep (they inherit the `sweep_id` from the parent's context).
32
- - The objective function passed to `wandb.agent()` should be a vanilla Python function decorated with `@wandb_init` to initialize the run. You can access the run with `wandb.run` since the Flyte context won't be available during the objective function call.
33
- - Set `download_logs=True` in `wandb_config` or `@wandb_init` to download W&B run logs after task completion. The I/O of this download functionality is traced by Flyte's `@flyte.trace`.
34
- - Set `download_logs=True` in `wandb_sweep_config` or `@wandb_sweep` to download W&B sweep logs after task completion. The I/O of this download functionality is traced by Flyte's `@flyte.trace`.
@@ -1,8 +0,0 @@
1
- flyteplugins/wandb/__init__.py,sha256=D5gqDOIy6ePcE2tcbNVsp9ZzxZKC6Qmd-6eHxNX3L88,15881
2
- flyteplugins/wandb/_context.py,sha256=va_TlRhSW-QBbHhvKmIAggsLw5VFAq4gXMIu7n5ZKSA,12746
3
- flyteplugins/wandb/_decorator.py,sha256=HenEVJI7kmDMQdHo6jDy3vXvjxT89CCYRBCR2CuGE3s,14785
4
- flyteplugins/wandb/_link.py,sha256=tEzfW06GPsVMECGAnEhwNzCI2h0d0UnJHMqso6t8Pnw,5319
5
- flyteplugins_wandb-2.0.0b52.dist-info/METADATA,sha256=oOQOpcjQa99Iy-bhYuop9KtFktttFzNzNYj4yQvUjBc,3058
6
- flyteplugins_wandb-2.0.0b52.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
7
- flyteplugins_wandb-2.0.0b52.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
8
- flyteplugins_wandb-2.0.0b52.dist-info/RECORD,,