robo-goggles 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,191 @@
1
+ """Utility functions for history slicing and inspection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional, Sequence, Tuple, Union
6
+
7
+ import jax
8
+ import jaxlib.xla_client as xc
9
+
10
# Alias for the XLA device type; used to annotate the `devices`
# parameter of `to_device` below.
Device = xc.Device
11
+
12
+ from .types import History
13
+
14
+
15
def slice_history(
    history: History,
    start: int,
    length: int,
    fields: Optional[Union[Sequence[str], str]] = None,
) -> History:
    """Return a temporal slice [start : start+length] for selected fields.

    Args:
        history (History): Mapping field -> array of shape (B, T, ...).
        start (int): Starting timestep (0-based).
        length (int): Number of timesteps to include (> 0).
        fields (Optional[Sequence[str] | str]): One or more field names to slice.
            If a single string is provided, only that field is sliced.
            If a list or tuple is provided, all listed fields are sliced.
            If None, all fields in `history` are sliced.

    Returns:
        History: Mapping of sliced arrays with shape (B, length, ...).

    Raises:
        ValueError: If `length` <= 0, `start` < 0, `fields` is an empty
            list/tuple, or the slice exceeds the time dimension T of the
            first array or of any selected field.
        KeyError: If any name in `fields` is not present in `history`.
        TypeError: If `history` is empty, `fields` has an unsupported type,
            or a checked array has rank < 2.

    """
    # Validate length and history
    if length <= 0:
        raise ValueError("length must be > 0")
    if not history:
        raise TypeError("history must be a non-empty mapping")

    # Validate reference array and slice bounds (fast fail before parsing fields)
    any_arr = next(iter(history.values()))
    if any_arr.ndim < 2:
        raise TypeError("history arrays must have rank >= 2 (B, T, ...)")
    T = any_arr.shape[1]
    if start < 0 or start + length > T:
        raise ValueError(f"Invalid slice [{start}:{start+length}] for T={T}")

    # Normalize and validate `fields`
    if fields is None:
        keys = list(history.keys())
    elif isinstance(fields, str):
        keys = [fields]
    elif isinstance(fields, (list, tuple)):
        if not fields:
            raise ValueError("fields list is empty.")
        if not all(isinstance(f, str) for f in fields):
            raise TypeError("All field names must be strings.")
        keys = list(fields)
    else:
        raise TypeError("fields must be a string, list/tuple of strings, or None")

    # Check that all requested fields exist
    missing = set(keys) - set(history)
    if missing:
        raise KeyError(f"Unknown fields: {missing}")

    stop = start + length
    # Validate rank and time length for every selected field. Fields may have
    # different T; slicing past the end would silently return fewer than
    # `length` steps instead of the documented (B, length, ...) shape.
    for k in keys:
        arr = history[k]
        if arr.ndim < 2:
            raise TypeError(f"Field {k!r} must have rank >= 2 (B, T, ...)")
        field_T = arr.shape[1]
        if stop > field_T:
            raise ValueError(
                f"Invalid slice [{start}:{stop}] for field {k!r} with T={field_T}"
            )

    return {k: history[k][:, start:stop, ...] for k in keys}
80
+
81
+
82
def peek_last(history: History, k: int = 1) -> History:
    """Return the last `k` timesteps for all fields.

    Args:
        history (History): Mapping field -> array of shape (B, T, *payload).
        k (int): Number of trailing timesteps to select (1 <= k <= T for
            every field).

    Returns:
        History: Mapping field -> sliced array of shape (B, k, *payload).

    Raises:
        ValueError: If `k` < 1 or `k` > T for any field.
        TypeError: If `history` is empty or contains tensors with rank < 2.

    """
    if not history:
        raise TypeError("history must be a non-empty mapping")

    # Check every field, not only the first: fields may differ in T, and a
    # shorter field would otherwise silently yield fewer than `k` steps
    # (and a rank-1 field would fail later with an opaque IndexError).
    for arr in history.values():
        if arr.ndim < 2:
            raise TypeError("history arrays must have rank >= 2 (B, T, ...)")
        T = arr.shape[1]
        if k < 1 or k > T:
            raise ValueError(f"k must be in [1, T]; got k={k}, T={T}")

    # Use negative slicing for clarity and to keep JAX-friendly semantics.
    return {name: v[:, -k:, ...] for name, v in history.items()}
110
+
111
+
112
def to_device(
    history: History,
    devices: Optional[Sequence[Device]] = None,  # type: ignore[type-arg]
    keys: Optional[Tuple[str, ...]] = None,
) -> History:
    """Move selected history arrays to one or more JAX devices.

    This function moves JAX arrays contained in a dictionary (or subset of it)
    to a target device or set of devices. If multiple devices are provided,
    fields are distributed in a simple round-robin fashion across them (all
    leaves of one field go to the same device).

    Only ``jax.Array`` leaves are moved; every other leaf (metadata, scalars,
    strings, NumPy arrays) is left unchanged.

    Args:
        history (History): Mapping field to array (or PyTree of arrays).
        devices (Optional[Sequence[Device]]): Target devices. If None or empty,
            everything is placed on the first device from ``jax.devices()``.
        keys (Optional[tuple[str, ...]]): Subset of fields to move. If None,
            move all. Names not present in `history` are ignored.

    Returns:
        History: Copy of the history with selected arrays placed on the
        target device(s). Unselected fields are carried over unchanged.

    """
    if not devices:
        # Documented default: a single (first) device. Round-robin over every
        # available accelerator would scatter fields of one history across
        # devices, which breaks later computations that combine them.
        devices = jax.devices()[:1]

    # Select subset of keys if specified
    subset = history if keys is None else {k: history[k] for k in keys if k in history}

    moved = {}
    for i, (name, value) in enumerate(subset.items()):
        device = devices[i % len(devices)]  # Round-robin device selection

        # Recursively move PyTree leaves to the device (device bound as a
        # default argument to make the capture explicit).
        moved[name] = jax.tree_util.tree_map(
            lambda x, d=device: jax.device_put(x, d) if isinstance(x, jax.Array) else x,
            value,
        )

    # If all keys were moved, just return the moved version
    if keys is None:
        return moved

    # Otherwise, merge moved subset back into the original dict
    return {**history, **moved}
155
+
156
+
157
def to_host(
    history: History,
    keys: Optional[Tuple[str, ...]] = None,
) -> History:
    """Copy selected history arrays from device to host memory.

    Every ``jax.Array`` leaf inside the selected fields is fetched with
    ``jax.device_get`` (yielding host NumPy data); any other leaf is passed
    through untouched.

    Args:
        history (History): Mapping field to array (or PyTree of arrays).
        keys (Optional[tuple[str, ...]]): Fields to copy. If None, every field
            is copied. Names absent from `history` are skipped.

    Returns:
        History: New mapping. When `keys` is given, unselected fields are
        carried over from `history` as-is.

    Example:
        >>> host_history = to_host(device_history)
        >>> type(host_history["loss"])
        <class 'numpy.ndarray'>

    """

    def _fetch(leaf):
        # Only device arrays need transferring; leave everything else alone.
        if isinstance(leaf, jax.Array):
            return jax.device_get(leaf)
        return leaf

    if keys is None:
        selected = history
    else:
        selected = {name: history[name] for name in keys if name in history}

    copied = {
        name: jax.tree_util.tree_map(_fetch, tree) for name, tree in selected.items()
    }

    return copied if keys is None else {**history, **copied}
goggles/media.py ADDED
@@ -0,0 +1,284 @@
1
+ """Media utilities for saving images and videos from numpy arrays."""
2
+
3
+ from typing import Literal
4
+ import numpy as np
5
+ import imageio
6
+ from pathlib import Path
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
10
+
11
+
12
+ def _to_uint8(arr: np.ndarray) -> np.ndarray:
13
+ if arr.dtype == np.uint8:
14
+ return arr
15
+ a = arr.astype(np.float32)
16
+ vmin = float(np.min(a))
17
+ vmax = float(np.max(a))
18
+ if np.isclose(vmax, vmin):
19
+ return np.zeros_like(a, dtype=np.uint8)
20
+ if vmin >= 0.0 and vmax <= 1.0:
21
+ a = a * 255.0
22
+ else:
23
+ a = (a - vmin) / (vmax - vmin) * 255.0
24
+ return np.clip(a, 0, 255).astype(np.uint8)
25
+
26
+
27
def _normalize_frames(frames: np.ndarray):
    """Validate a clip and convert it to ``uint8`` frames plus a color mode.

    Args:
        frames (np.ndarray): Clip of shape (T, H, W) or (T, ..., C) with C in
            {1, 3}. A trailing singleton channel axis is dropped.

    Returns:
        tuple[np.ndarray, str]: The ``uint8`` clip and ``"L"`` for
        single-channel data or ``"RGB"`` for three-channel data.

    Raises:
        ValueError: If the clip has fewer than 3 dims or an unsupported
            channel count.

    """
    clip = np.asarray(frames)
    if clip.ndim < 3:
        raise ValueError("Expected shape (T, *image_shape[, C]).")

    # A bare (T, H, W) clip is implicitly single-channel.
    channels = 1 if clip.ndim == 3 else clip.shape[-1]
    if channels not in (1, 3):
        raise ValueError(f"Last dimension must be 1 or 3 channels, got {channels}.")

    if channels == 1 and clip.ndim >= 4:
        clip = clip[..., 0]  # drop singleton channel axis

    clip_u8 = _to_uint8(clip)
    is_gray = clip_u8.ndim == 3 or channels == 1
    return clip_u8, ("L" if is_gray else "RGB")
43
+
44
+
45
+ def _ensure_even_hw(u8: np.ndarray) -> np.ndarray:
46
+ if u8.ndim == 3: # (T,H,W)
47
+ T, H, W = u8.shape
48
+ pad_h, pad_w = H % 2, W % 2
49
+ if pad_h or pad_w:
50
+ out = np.zeros((T, H + pad_h, W + pad_w), dtype=np.uint8)
51
+ out[:, :H, :W] = u8
52
+ return out
53
+ return u8
54
+ else: # (T,H,W,3)
55
+ T, H, W, C = u8.shape
56
+ pad_h, pad_w = H % 2, W % 2
57
+ if pad_h or pad_w:
58
+ out = np.zeros((T, H + pad_h, W + pad_w, C), dtype=np.uint8)
59
+ out[:, :H, :W, :] = u8
60
+ return out
61
+ return u8
62
+
63
+
64
def save_numpy_gif(
    frames: np.ndarray, out_path: str, fps: int = 10, loop: int = 0
) -> None:
    """Write a NumPy clip to an animated GIF with imageio.

    Args:
        frames (np.ndarray): Clip of shape (T, H, W) or (T, H, W, C), C in {1, 3}.
        out_path (str): Destination file path.
        fps (int): Playback rate; each frame is given ``1 / fps`` as its
            ``duration``. NOTE(review): newer imageio pillow GIF plugins read
            ``duration`` in milliseconds — confirm against the installed
            imageio version.
        loop (int): Loop count passed to the writer (0 = loop forever).

    """
    clip_u8, _mode = _normalize_frames(frames)
    frame_list = list(clip_u8)  # one (H, W[, 3]) array per timestep
    frame_duration = 1.0 / float(fps)
    imageio.mimsave(
        out_path,
        frame_list,
        format="GIF",
        duration=frame_duration,
        loop=loop,
        palettesize=256,
        subrectangles=True,
    )
88
+
89
+
90
def save_numpy_mp4(
    frames: np.ndarray,
    out_path: Path,
    fps: int = 30,
    codec: str = "libx264",
    pix_fmt: str = "yuv420p",
    bitrate: str | None = None,
    crf: int | None = 18,
    convert_gray_to_rgb: bool = True,
    preset: str | None = "medium",
) -> None:
    """Save a NumPy clip to MP4 using imageio-ffmpeg.

    Args:
        frames (np.ndarray): Input clip as a NumPy array of shape
            (T, H, W) or (T, H, W, C) where C is 1 or 3.
        out_path (Path | str): Output file path.
        fps (int): Frames per second.
        codec (str): Video codec to use.
        pix_fmt (str): Output pixel format, passed to ffmpeg as ``-pix_fmt``
            (omitted if None).
        bitrate (str | None): Bitrate string for ffmpeg (e.g. "4M").
        crf (int | None): Constant Rate Factor, passed as ``-crf`` (omitted
            if None).
        convert_gray_to_rgb (bool): Whether to replicate single-channel
            frames into three channels before encoding.
        preset (str | None): ffmpeg ``-preset`` for the speed/quality
            tradeoff (omitted if None).

    """
    arr_u8, mode = _normalize_frames(frames)

    # Prefer RGB for broad player compatibility
    if mode == "L" and convert_gray_to_rgb:
        arr_u8 = np.stack([arr_u8] * 3, axis=-1)  # (T,H,W) -> (T,H,W,3)

    # Ensure even dims for yuv420p
    arr_u8 = _ensure_even_hw(arr_u8)

    # Build writer kwargs. macro_block_size=None disables imageio's automatic
    # rescale to multiples of 16; evenness is already handled above.
    writer_kwargs = {
        "fps": fps,
        "codec": codec,
        "macro_block_size": None,
        "format": "FFMPEG",
    }
    if bitrate is not None:
        writer_kwargs["bitrate"] = bitrate

    ffmpeg_params = []
    if crf is not None:
        ffmpeg_params += ["-crf", str(crf)]
    if preset is not None:
        ffmpeg_params += ["-preset", preset]
    if pix_fmt is not None:
        ffmpeg_params += ["-pix_fmt", pix_fmt]
    if ffmpeg_params:
        writer_kwargs["ffmpeg_params"] = ffmpeg_params

    with imageio.get_writer(out_path, **writer_kwargs) as writer:
        # Gray (T,H,W) and RGB (T,H,W,3) clips are written identically:
        # one frame per leading index.
        for frame in arr_u8:
            writer.append_data(frame)
152
+
153
+
154
def save_numpy_image(image: np.ndarray, out_path: str, format: str) -> None:
    """Write a single NumPy image to disk through imageio.

    The image is treated as a one-frame clip, so it gets the same validation
    and ``uint8`` conversion as the video helpers.

    Args:
        image (np.ndarray): Image of shape (H, W) or (H, W, C), C in {1, 3}.
        out_path (str): Destination file path.
        format (str): Image format forwarded to imageio (e.g. 'png', 'jpg').

    """
    batch_u8, _mode = _normalize_frames(image[None, ...])
    single_frame = batch_u8[0]
    imageio.imwrite(out_path, single_frame, format=format)
166
+
167
+
168
def save_numpy_vector_field_visualization(
    vector_field: np.ndarray,
    dir: Path,
    name: str,
    mode: Literal["vorticity", "magnitude"] = "magnitude",
    arrow_stride: int = 8,
    dpi: int = 300,
    add_colorbar: bool = True,
) -> None:
    """Save a 2D vector field visualization as a PNG image.

    The plot is rendered on a standalone ``matplotlib.figure.Figure`` rather
    than through pyplot. This leaves the global matplotlib backend untouched
    and does not close or leak any pyplot figures of the caller.

    Args:
        vector_field (np.ndarray): Input vector field of shape (H, W, 2);
            ``[..., 0]`` is drawn as the x (column) component and ``[..., 1]``
            as the y (row) component.
        dir (Path): Output directory (a ``str`` is also accepted); created if
            missing.
        name (str): Base name for the output PNG file (without extension).
        mode (Literal["vorticity", "magnitude"]): Background scalar field:
            ``"magnitude"`` shows |v|, ``"vorticity"`` shows the 2D curl
            dVy/dx - dVx/dy.
        arrow_stride (int): Stride for downsampling arrows (every Nth point).
        dpi (int): Resolution of the output image.
        add_colorbar (bool): Whether to include a colorbar.

    Raises:
        ValueError: If `mode` is not "magnitude" or "vorticity".

    """
    from matplotlib.figure import Figure

    H, W, _ = vector_field.shape

    # Compute scalar field and arrow color first, so an invalid mode fails
    # before any figure is created.
    if mode == "magnitude":
        scalar_field = np.linalg.norm(vector_field, axis=-1)
        cmap = plt.cm.viridis
        arrow_color = "white"
    elif mode == "vorticity":
        # Scalar curl: dVy/dx - dVx/dy (axis 0 = y/rows, axis 1 = x/cols).
        dvx_dy = np.gradient(vector_field[..., 0], axis=0)
        dvy_dx = np.gradient(vector_field[..., 1], axis=1)
        scalar_field = dvy_dx - dvx_dy
        cmap = plt.cm.RdBu_r
        arrow_color = "black"
    else:
        raise ValueError("mode must be 'magnitude' or 'vorticity'")

    # Figure that matches pixel aspect; not registered with pyplot, so it is
    # freed with the local reference and needs no plt.close().
    fig = Figure(figsize=(W / 50, H / 50), dpi=dpi)
    ax = fig.subplots()
    ax.set_aspect("equal")

    # Display scalar field as background
    im = ax.imshow(
        scalar_field,
        cmap=cmap,
        origin="lower",
        extent=[0, W, 0, H],
        interpolation="bilinear",
    )

    # Arrow grid
    y_coords, x_coords = np.mgrid[0:H:arrow_stride, 0:W:arrow_stride]
    u_sampled = vector_field[::arrow_stride, ::arrow_stride, 0]
    v_sampled = vector_field[::arrow_stride, ::arrow_stride, 1]

    # Plot arrows
    ax.quiver(
        x_coords,
        y_coords,
        u_sampled,
        v_sampled,
        color=arrow_color,
        alpha=0.9,
        scale_units="xy",
        scale=1,
        width=0.002,
        headwidth=4,
        headlength=5,
        headaxislength=4.5,
        linewidth=0.5,
        edgecolor="none",
    )

    # Remove axes and ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

    if add_colorbar:
        # Colorbar exactly the same height as the axes; "3%" is a thin bar,
        # pad is the gap.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.02)
        cb = fig.colorbar(im, cax=cax)
        cb.ax.tick_params(length=2)

        # Leave a tiny margin so ticks/labels aren't clipped
        fig.subplots_adjust(left=0.0, right=0.98, bottom=0.0, top=1.0)
        pad_setting = 0.02
    else:
        # No colorbar: strip all outer boundaries/margins
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        pad_setting = 0.0

    # Save
    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        str(out_dir / f"{name}.png"),
        dpi=dpi,
        bbox_inches="tight",
        pad_inches=pad_setting,
        facecolor="white",
        edgecolor="none",
    )
goggles/shutdown.py ADDED
@@ -0,0 +1,70 @@
1
+ """Simple util for graceful shutdowns in Python applications."""
2
+
3
+ import signal
4
+ from typing import Optional
5
+
6
+
7
class GracefulShutdown:
    """A context manager for graceful shutdowns.

    While the context is active, SIGINT and SIGTERM no longer interrupt the
    program. Instead they set ``stop`` to ``True`` and log ``exit_message``
    if one was given, so the main loop can finish its current iteration and
    exit cleanly. The previously installed handlers are restored on exit.
    ``stop`` is reset to ``False`` each time the context is entered, so an
    instance can be reused.

    Note:
        ``signal.signal`` may only be called from the main thread of the main
        interpreter; entering this context from another thread raises
        ``ValueError``.

    Example:
        >>> with GracefulShutdown(exit_message="Shutting down gracefully...") as gs:
        ...     while not gs.stop:
        ...         # Main application logic here, runs until interrupted
        ...         # by SIGINT or SIGTERM.
        ...         pass
        ...     print("Cleanup and exit.")

    """

    # Class-level default so `stop` is readable even before entering.
    stop = False

    def __init__(
        self,
        exit_message: Optional[str] = None,
    ):
        """Initialize the GracefulShutdown context manager.

        Args:
            exit_message (Optional[str]): Message logged at INFO level when a
                handled signal arrives. Nothing is logged if None or empty.

        """
        # Imported lazily, presumably to avoid a circular import with the
        # package's __init__.
        from . import get_logger

        self.logger = get_logger("goggles.shutdown")
        self.exit_message = exit_message
        self.stop = False
        # placeholders for original handlers
        self._orig_sigint = None
        self._orig_sigterm = None

    def __enter__(self):
        """Reset `stop` and register the SIGINT/SIGTERM handlers.

        Returns:
            GracefulShutdown: This instance.

        """
        # Reset so a reused instance does not exit immediately because of a
        # signal received during a previous `with` block.
        self.stop = False

        # save existing handlers
        self._orig_sigint = signal.getsignal(signal.SIGINT)
        self._orig_sigterm = signal.getsignal(signal.SIGTERM)

        def handle_signal(signum, frame):
            self.stop = True
            if self.exit_message:
                self.logger.info(self.exit_message)

        # register for both SIGINT and SIGTERM
        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Unregister the signal handlers, restoring originals.

        Exceptions raised inside the block are not suppressed.

        Args:
            exc_type: Exception type if any.
            exc_value: Exception value if any.
            traceback: Traceback if any.

        """
        # restore original handlers; getsignal() returns None when the prior
        # handler was not installed from Python, and such a handler cannot be
        # re-installed via signal.signal().
        if self._orig_sigint is not None:
            signal.signal(signal.SIGINT, self._orig_sigint)
        if self._orig_sigterm is not None:
            signal.signal(signal.SIGTERM, self._orig_sigterm)
goggles/types.py ADDED
@@ -0,0 +1,79 @@
1
+ """Types used in Goggles."""
2
+
3
+ import numpy as np
4
+ from typing import Dict, Literal, Any, Optional
5
+ from dataclasses import dataclass
6
+ from typing import TypeAlias
7
+
8
+ Kind = Literal["log", "metric", "image", "video", "artifact"]
9
+
10
+ Metrics = Dict[str, float | int]
11
+ Image: TypeAlias = np.ndarray
12
+ Video: TypeAlias = np.ndarray
13
+ Vector: TypeAlias = np.ndarray
14
+ VectorField: TypeAlias = np.ndarray
15
+
16
+
17
@dataclass(frozen=True)
class Event:
    """Structured event routed through the EventBus.

    Args:
        kind (Kind): Kind of event ("log", "metric", "image", "video",
            "artifact").
        scope (str): Scope of the event ("global" or "run").
        payload (Any): Event payload.
        filepath (str): File path of the caller emitting the event.
        lineno (int): Line number of the caller emitting the event.
        level (Optional[int]): Optional log level for "log" events.
        step (Optional[int]): Optional global step index.
        time (Optional[float]): Optional global timestamp.
        extra (Optional[dict[str, Any]]): Optional extra metadata.

    """

    kind: Kind
    scope: str
    payload: Any
    filepath: str
    lineno: int
    level: Optional[int] = None
    step: Optional[int] = None
    time: Optional[float] = None
    extra: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Event to a plain dictionary.

        Required fields are always present. The optional fields ``level``,
        ``step``, ``time`` and ``extra`` are included, in that order, only
        when they are not None. Values are not copied, so ``payload`` and
        ``extra`` are shared with the event.

        Returns:
            Dict[str, Any]: Dictionary accepted by :meth:`from_dict`.

        """
        result: Dict[str, Any] = {
            "kind": self.kind,
            "scope": self.scope,
            "payload": self.payload,
            "filepath": self.filepath,
            "lineno": self.lineno,
        }

        # Only include optional fields if they are not None
        for key in ("level", "step", "time", "extra"):
            value = getattr(self, key)
            if value is not None:
                result[key] = value

        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Event":
        """Create an Event from a dictionary such as one from :meth:`to_dict`.

        Args:
            data (Dict[str, Any]): Mapping with the required keys ``kind``,
                ``scope``, ``payload``, ``filepath`` and ``lineno``. Missing
                optional keys default to None.

        Returns:
            Event: The reconstructed event.

        Raises:
            KeyError: If a required key is missing.

        """
        return cls(
            kind=data["kind"],
            scope=data["scope"],
            payload=data["payload"],
            filepath=data["filepath"],
            lineno=data["lineno"],
            level=data.get("level"),
            step=data.get("step"),
            time=data.get("time"),
            extra=data.get("extra"),
        )