robo-goggles 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,253 @@
1
+ """WandB integration handler for Goggles logging framework."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, ClassVar, Dict, FrozenSet, Literal, Mapping, Optional
7
+ from typing_extensions import Self
8
+
9
+ import wandb
10
+
11
+ from goggles.types import Kind
12
+
13
# Stand-in type for a W&B run object (wandb.sdk.wandb_run.Run); kept as Any
# instead of referencing the wandb class directly.
Run = Any

# Reinitialization strategies this module forwards to `wandb.init(reinit=...)`.
Reinit = Literal[
    "default",
    "return_previous",
    "finish_previous",
    "create_new",
]
15
+
16
+
17
class WandBHandler:
    """Forward Goggles events to W&B runs (supports concurrent scopes).

    Each scope corresponds to a distinct W&B run that remains active until
    explicitly closed. Compatible with the `Handler` protocol used by the
    EventBus.

    Attributes:
        name (str): Stable handler identifier.
        capabilities (FrozenSet[Kind]): Supported event kinds
            ({"metric", "image", "video", "artifact"}).
        GLOBAL_SCOPE (str): Scope used for events that carry no scope.

    """

    name: str = "wandb"
    capabilities: ClassVar[FrozenSet[Kind]] = frozenset(
        {"metric", "image", "video", "artifact"}
    )
    GLOBAL_SCOPE: ClassVar[str] = "global"
    # Accepted values for the `reinit` constructor argument (mirrors `Reinit`).
    _VALID_REINIT: ClassVar[FrozenSet[str]] = frozenset(
        {"default", "return_previous", "finish_previous", "create_new"}
    )
    # Video formats forwarded to `wandb.Video`; anything else falls back to mp4.
    _VIDEO_FORMATS: ClassVar[FrozenSet[str]] = frozenset({"mp4", "gif"})

    def __init__(
        self,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        run_name: Optional[str] = None,
        config: Optional[Mapping[str, Any]] = None,
        group: Optional[str] = None,
        reinit: Reinit = "create_new",
    ) -> None:
        """Initialize the W&B handler.

        No W&B run is started here; runs are created by `open()` or lazily
        when the first loggable event for a scope arrives.

        Args:
            project (Optional[str]): W&B project name.
            entity (Optional[str]): W&B entity (user or team) name.
            run_name (Optional[str]): Base name for W&B runs.
            config (Optional[Mapping[str, Any]]): Configuration dictionary
                to log with the run(s). It is copied on construction.
            group (Optional[str]): W&B group name for runs.
            reinit (Reinit): W&B reinitialization strategy when opening runs.
                One of {"default", "return_previous", "finish_previous",
                "create_new"}.

        Raises:
            ValueError: If `reinit` is not one of the accepted values.

        """
        self._logger = logging.getLogger(self.name)
        # NOTE(review): the logger name "wandb" may coincide with the wandb
        # library's own logger; presumably propagation is forced so these
        # messages reach the application's root handlers — confirm.
        self._logger.propagate = True
        if reinit not in self._VALID_REINIT:
            # Sorted so the message is stable across interpreter runs.
            raise ValueError(
                f"Invalid reinit value '{reinit}'. Must be one of: "
                f"{', '.join(sorted(self._VALID_REINIT))}."
            )

        self._project = project
        self._entity = entity
        self._group = group
        self._base_run_name = run_name
        self._config: Dict[str, Any] = dict(config) if config is not None else {}
        self._reinit: Reinit = reinit
        # scope -> W&B run; each scope owns exactly one run until close().
        self._runs: Dict[str, Run] = {}
        self._wandb_run: Optional[Run] = None
        self._current_scope: Optional[str] = None

    def can_handle(self, kind: str) -> bool:
        """Return True if the handler supports this event kind.

        Args:
            kind (str): Kind of event ("metric", "image", "video", "artifact").

        Returns:
            bool: True if the kind is supported, False otherwise.

        """
        return kind in self.capabilities

    def open(self) -> None:
        """Initialize (or adopt) the W&B run for the global scope.

        Does nothing if the global run is already open. If a global-scope
        event was handled before `open()`, the run lazily created for it is
        adopted. Starting a second run would overwrite that run in `_runs`,
        and `close()` would then never finish it.
        """
        if self._wandb_run is not None:
            return
        run = self._runs.get(self.GLOBAL_SCOPE)
        if run is None:
            run = wandb.init(
                project=self._project,
                entity=self._entity,
                name=self._base_run_name,
                config=self._config,
                reinit=self._reinit,  # type: ignore
                group=self._group,
            )
            self._runs[self.GLOBAL_SCOPE] = run
        self._wandb_run = run
        self._current_scope = self.GLOBAL_SCOPE

    def handle(self, event: Any) -> None:
        """Process a Goggles event and forward it to W&B.

        The event is read via its `scope`, `kind`, `step`, `payload` and
        `extra` attributes. Missing or falsy values fall back to the global
        scope, the "metric" kind, no step, no payload and no extras. The run
        for the event's scope is created only once there is something to
        log. Unsupported kinds and invalid or empty payloads therefore never
        start a new W&B run.

        Args:
            event (Any): The Goggles event to process.

        Raises:
            ValueError: If a metric event's payload is not a mapping.

        """
        scope = getattr(event, "scope", None) or self.GLOBAL_SCOPE
        kind = getattr(event, "kind", None) or "metric"
        step = getattr(event, "step", None)
        payload = getattr(event, "payload", None)
        extra = getattr(event, "extra", {}) or {}

        if kind == "metric":
            self._handle_metric(scope, payload, step)
        elif kind in ("image", "video"):
            self._handle_media(scope, kind, payload, step, extra)
        elif kind == "artifact":
            self._handle_artifact(scope, payload)
        else:
            self._logger.warning("Unsupported event kind: %s", kind)

    def _handle_metric(self, scope: str, payload: Any, step: Any) -> None:
        """Log a mapping of metric name -> value to the scope's run."""
        if not isinstance(payload, Mapping):
            raise ValueError(
                "Metric event payload must be a mapping of name→value."
            )
        self._get_or_create_run(scope).log(dict(payload), step=step)

    def _handle_media(
        self,
        scope: str,
        kind: str,
        payload: Any,
        step: Any,
        extra: Mapping[str, Any],
    ) -> None:
        """Log image/video payload(s) as `wandb.Image` / `wandb.Video`.

        The payload is either a mapping {name: data} or a single datum. A
        single datum is keyed by extra["name"], falling back to the kind
        ("image" / "video"). Entries whose data is None are skipped with a
        warning.
        """
        key_name = extra.get("name", kind)
        items = (
            list(payload.items())
            if isinstance(payload, Mapping)
            else [(key_name, payload)]
        )

        present = []
        for item_name, value in items:
            if value is None:
                self._logger.warning(
                    "Skipping %s '%s' with None payload (scope=%s).",
                    kind,
                    item_name,
                    scope,
                )
                continue
            present.append((item_name, value))
        if not present:
            # Nothing to log: do not spin up a run for this scope.
            return

        run = self._get_or_create_run(scope)
        if kind == "image":
            logs = {item_name: wandb.Image(value) for item_name, value in present}
        else:
            # Options are shared by every video in the event, so parse and
            # validate them once; warn (once) naming the first video.
            fps = int(extra.get("fps", 20))
            fmt = str(extra.get("format", "mp4"))
            if fmt not in self._VIDEO_FORMATS:
                self._logger.warning(
                    "Unsupported video format '%s' for '%s'; defaulting to 'mp4'.",
                    fmt,
                    present[0][0],
                )
                fmt = "mp4"
            logs = {
                item_name: wandb.Video(value, fps=fps, format=fmt)  # type: ignore
                for item_name, value in present
            }
        # Use a single API across kinds for consistency.
        run.log(logs, step=step)

    def _handle_artifact(self, scope: str, payload: Any) -> None:
        """Upload a single file as a W&B artifact.

        Expects a mapping with a string "path" and optional "name" (default
        "artifact") and "type" (default "misc"). Invalid payloads are logged
        and skipped.
        """
        if not isinstance(payload, Mapping):
            self._logger.warning(
                "Artifact payload must be a mapping; got %r", type(payload)
            )
            return
        path = payload.get("path")
        if not isinstance(path, str):
            self._logger.warning("Artifact missing valid 'path' field; skipping.")
            return
        run = self._get_or_create_run(scope)
        artifact = wandb.Artifact(
            name=payload.get("name", "artifact"),
            type=payload.get("type", "misc"),
        )
        artifact.add_file(path)
        run.log_artifact(artifact)

    def close(self) -> None:
        """Finish all active W&B runs and reset the handler state.

        Finishing is best-effort. A run whose `finish()` raises is logged and
        skipped, so the remaining runs are still finished and the state is
        still cleared.
        """
        for scope, run in list(self._runs.items()):
            if run is None:
                continue
            try:
                run.finish()
            except Exception:
                self._logger.warning(
                    "Failed to finish W&B run for scope '%s'.",
                    scope,
                    exc_info=True,
                )
        self._runs.clear()
        self._wandb_run = None
        self._current_scope = None

    def to_dict(self) -> Dict:
        """Serialize the handler for attachment.

        Returns:
            Dict: {"cls": <class name>, "data": <constructor kwargs>}. The
            config is copied so callers cannot mutate the handler through it.

        """
        return {
            "cls": self.__class__.__name__,
            "data": {
                "project": self._project,
                "entity": self._entity,
                "run_name": self._base_run_name,
                "config": dict(self._config),
                "reinit": self._reinit,
                "group": self._group,
            },
        }

    @classmethod
    def from_dict(cls, serialized: Dict) -> Self:
        """De-serialize the handler from its dictionary representation.

        Accepts either the full output of `to_dict()` ({"cls", "data"}) or
        just its "data" payload of constructor arguments.

        Args:
            serialized (Dict): Serialized handler.

        Returns:
            Self: A new, unopened handler instance.

        """
        data = serialized.get("data")
        if not isinstance(data, Mapping):
            data = serialized
        return cls(
            project=data.get("project"),
            entity=data.get("entity"),
            run_name=data.get("run_name"),
            config=data.get("config"),
            reinit=data.get("reinit", "create_new"),
            group=data.get("group"),
        )

    def _get_or_create_run(self, scope: str) -> Run:
        """Get or create a W&B run for the given scope.

        New runs are named after the base run name. The global scope uses
        the base name as-is when one is set; every other scope (and the
        global scope without a base name) uses "<base or 'run'>-<scope>".
        The scope is also added to the run's config.

        Args:
            scope (str): The scope for which to get or create the W&B run.

        Returns:
            Run: The W&B run associated with the given scope.

        """
        run = self._runs.get(scope)
        if run is not None:
            return run
        name = (
            self._base_run_name
            if scope == self.GLOBAL_SCOPE and self._base_run_name
            else f"{self._base_run_name or 'run'}-{scope}"
        )
        run = wandb.init(
            project=self._project,
            entity=self._entity,
            name=name,
            config={**self._config, "scope": scope},
            group=self._group,
            reinit=self._reinit,  # type: ignore
        )
        self._runs[scope] = run
        return run