robo-goggles 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goggles/__init__.py +786 -0
- goggles/_core/integrations/__init__.py +26 -0
- goggles/_core/integrations/console.py +111 -0
- goggles/_core/integrations/storage.py +382 -0
- goggles/_core/integrations/wandb.py +253 -0
- goggles/_core/logger.py +602 -0
- goggles/_core/routing.py +127 -0
- goggles/config.py +68 -0
- goggles/decorators.py +81 -0
- goggles/history/__init__.py +39 -0
- goggles/history/buffer.py +185 -0
- goggles/history/spec.py +143 -0
- goggles/history/types.py +9 -0
- goggles/history/utils.py +191 -0
- goggles/media.py +284 -0
- goggles/shutdown.py +70 -0
- goggles/types.py +79 -0
- robo_goggles-0.1.0.dist-info/METADATA +600 -0
- robo_goggles-0.1.0.dist-info/RECORD +22 -0
- robo_goggles-0.1.0.dist-info/WHEEL +5 -0
- robo_goggles-0.1.0.dist-info/licenses/LICENSE +21 -0
- robo_goggles-0.1.0.dist-info/top_level.txt +1 -0
goggles/history/utils.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Utility functions for history slicing and inspection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Optional, Sequence, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import jaxlib.xla_client as xc
|
|
9
|
+
|
|
10
|
+
# Alias for the JAX device type; in this module it is only used to annotate
# the `devices` parameter of `to_device`.
Device = xc.Device
|
|
11
|
+
|
|
12
|
+
from .types import History
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def slice_history(
    history: History,
    start: int,
    length: int,
    fields: Optional[Union[Sequence[str], str]] = None,
) -> History:
    """Return a temporal slice ``[start : start + length]`` for selected fields.

    Args:
        history (History): Mapping field -> array of shape (B, T, ...).
        start (int): Starting timestep (0-based).
        length (int): Number of timesteps to include (> 0).
        fields (Optional[Sequence[str] | str]): One or more field names to slice.
            If a single string is provided, only that field is sliced.
            If a list or tuple is provided, all listed fields are sliced.
            If None, all fields in `history` are sliced.

    Returns:
        History: Mapping of sliced arrays with shape (B, length, ...), one entry
        per selected field.

    Raises:
        ValueError: If `length` <= 0, `start` < 0, `fields` is an empty
            list/tuple, or the slice exceeds T of any selected field.
        KeyError: If any requested field is not present in `history`.
        TypeError: If `history` is empty, `fields` has an unsupported type or
            non-string entries, or a selected field has rank < 2.
    """
    # Validate length and history
    if length <= 0:
        raise ValueError("length must be > 0")
    if not history:
        raise TypeError("history must be a non-empty mapping")

    # Normalize and validate `fields`
    if fields is None:
        keys = list(history.keys())
    elif isinstance(fields, str):
        keys = [fields]
    elif isinstance(fields, (list, tuple)):
        if not fields:
            raise ValueError("fields list is empty.")
        if not all(isinstance(f, str) for f in fields):
            raise TypeError("All field names must be strings.")
        keys = list(fields)
    else:
        raise TypeError("fields must be a string, list/tuple of strings, or None")

    # Check that all requested fields exist
    missing = set(keys) - set(history)
    if missing:
        raise KeyError(f"Unknown fields: {missing}")

    # Validate rank and bounds for every selected field, not just the first
    # value of `history`: array slicing silently clamps an out-of-range stop,
    # so a field with a shorter T would otherwise come back with fewer than
    # `length` timesteps instead of raising.
    stop = start + length
    for k in keys:
        arr = history[k]
        if arr.ndim < 2:
            raise TypeError(f"Field {k!r} must have rank >= 2 (B, T, ...)")
        T = arr.shape[1]
        if start < 0 or stop > T:
            raise ValueError(f"Invalid slice [{start}:{stop}] for T={T}")

    return {k: history[k][:, start:stop, ...] for k in keys}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def peek_last(history: History, k: int = 1) -> History:
    """Return the last `k` timesteps for all fields.

    Args:
        history (History): Mapping field -> array of shape (B, T, *payload).
        k (int): Number of trailing timesteps to select (1 <= k <= T).

    Returns:
        History: Mapping field -> sliced array of shape (B, k, *payload).

    Raises:
        ValueError: If `k` < 1 or `k` > T for any field.
        TypeError: If `history` is empty or contains tensors with rank < 2.
    """
    if not history:
        raise TypeError("history must be a non-empty mapping")

    # Validate every field, as documented, not only the first one. Negative
    # slicing silently clamps, so a field with T < k would otherwise return
    # fewer than k timesteps.
    for arr in history.values():
        if arr.ndim < 2:
            raise TypeError("history arrays must have rank >= 2 (B, T, ...)")
        T = arr.shape[1]
        if k < 1 or k > T:
            raise ValueError(f"k must be in [1, T]; got k={k}, T={T}")

    # Use negative slicing for clarity and to keep JAX-friendly semantics.
    return {k_name: v[:, -k:, ...] for k_name, v in history.items()}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def to_device(
    history: History,
    devices: Optional[Sequence[Device]] = None,  # type: ignore[type-arg]
    keys: Optional[Tuple[str, ...]] = None,
) -> History:
    """Move selected history arrays to one or more JAX devices.

    Every ``jax.Array`` leaf of each selected field (a field may be a single
    array or a PyTree) is placed with ``jax.device_put``. If several devices
    are given, whole fields are assigned to them round-robin in iteration
    order. Non-array leaves (metadata, scalars, strings, ...) are left
    unchanged.

    Args:
        history (History): Mapping field to array (or PyTree of arrays).
        devices (Optional[Sequence[Device]]): Target devices. If None or
            empty, only the first device from ``jax.devices()`` is used.
        keys (Optional[tuple[str, ...]]): Subset of fields to move; names not
            present in `history` are ignored. If None, move all.

    Returns:
        History: New mapping with the selected arrays on the target device(s);
        unselected fields are carried over unchanged.
    """
    # Default to a single device, as documented, rather than spreading the
    # fields over every local device.
    targets = list(devices) if devices else jax.devices()[:1]

    # Select subset of keys if specified
    subset = history if keys is None else {k: history[k] for k in keys if k in history}

    moved = {}
    for i, (k, v) in enumerate(subset.items()):
        device = targets[i % len(targets)]  # Round-robin device selection

        # Recursively move PyTree leaves to the device; `d=device` binds the
        # current device explicitly instead of relying on closure late-binding.
        moved[k] = jax.tree_util.tree_map(
            lambda x, d=device: jax.device_put(x, d) if isinstance(x, jax.Array) else x,
            v,
        )

    # If all keys were moved, just return the moved version
    if keys is None:
        return moved

    # Otherwise, merge moved subset back into the original dict
    return {**history, **moved}
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def to_host(
    history: History,
    keys: Optional[Tuple[str, ...]] = None,
) -> History:
    """Pull history arrays back from JAX devices into host (NumPy) memory.

    Each selected field may be a single array or a PyTree; every
    ``jax.Array`` leaf is fetched with ``jax.device_get`` and all other leaves
    pass through untouched.

    Args:
        history (History): Mapping field to array (or PyTree of arrays).
        keys (Optional[tuple[str, ...]]): Fields to fetch; names absent from
            `history` are skipped. If None, every field is fetched.

    Returns:
        History: A new mapping. With `keys` given, the remaining fields are
        copied over unchanged and keep their original order.

    Example:
        >>> host_history = to_host(device_history)
        >>> type(host_history["loss"])
        <class 'numpy.ndarray'>
    """

    def _fetch(leaf):
        # Only JAX arrays are copied to host; anything else is returned as-is.
        return jax.device_get(leaf) if isinstance(leaf, jax.Array) else leaf

    if keys is None:
        selected = list(history)
    else:
        selected = [name for name in keys if name in history]

    fetched = {name: jax.tree_util.tree_map(_fetch, history[name]) for name in selected}

    if keys is not None:
        return {**history, **fetched}
    return fetched
|
goggles/media.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
"""Media utilities for saving images and videos from numpy arrays."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
import numpy as np
|
|
5
|
+
import imageio
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
import matplotlib
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _to_uint8(arr: np.ndarray) -> np.ndarray:
|
|
13
|
+
if arr.dtype == np.uint8:
|
|
14
|
+
return arr
|
|
15
|
+
a = arr.astype(np.float32)
|
|
16
|
+
vmin = float(np.min(a))
|
|
17
|
+
vmax = float(np.max(a))
|
|
18
|
+
if np.isclose(vmax, vmin):
|
|
19
|
+
return np.zeros_like(a, dtype=np.uint8)
|
|
20
|
+
if vmin >= 0.0 and vmax <= 1.0:
|
|
21
|
+
a = a * 255.0
|
|
22
|
+
else:
|
|
23
|
+
a = (a - vmin) / (vmax - vmin) * 255.0
|
|
24
|
+
return np.clip(a, 0, 255).astype(np.uint8)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _normalize_frames(frames: np.ndarray):
    """Validate a clip and convert it to uint8 together with a mode string.

    Args:
        frames (np.ndarray): Array-like clip of shape (T, *image_shape) or
            (T, *image_shape, C) where C is 1 or 3; a 3-D input is treated
            as single-channel (T, H, W).

    Returns:
        tuple[np.ndarray, str]: ``(clip_u8, mode)`` where ``clip_u8`` is the
        uint8 clip with any trailing singleton channel dropped, and ``mode``
        is ``"L"`` for grayscale or ``"RGB"`` otherwise.

    Raises:
        ValueError: If the clip has fewer than 3 dims or the last dimension
            of a >=4-D clip is not 1 or 3.
    """
    clip = np.asarray(frames)
    rank = clip.ndim
    if rank < 3:
        raise ValueError("Expected shape (T, *image_shape[, C]).")

    # 3-D clips carry no explicit channel axis.
    channels = 1 if rank == 3 else clip.shape[-1]
    if channels not in (1, 3):
        raise ValueError(f"Last dimension must be 1 or 3 channels, got {channels}.")

    # Drop a trailing singleton channel so grayscale ends up as (T, H, W).
    if rank >= 4 and channels == 1:
        clip = clip[..., 0]

    clip_u8 = _to_uint8(clip)
    grayscale = clip_u8.ndim == 3 or channels == 1
    return clip_u8, ("L" if grayscale else "RGB")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _ensure_even_hw(u8: np.ndarray) -> np.ndarray:
|
|
46
|
+
if u8.ndim == 3: # (T,H,W)
|
|
47
|
+
T, H, W = u8.shape
|
|
48
|
+
pad_h, pad_w = H % 2, W % 2
|
|
49
|
+
if pad_h or pad_w:
|
|
50
|
+
out = np.zeros((T, H + pad_h, W + pad_w), dtype=np.uint8)
|
|
51
|
+
out[:, :H, :W] = u8
|
|
52
|
+
return out
|
|
53
|
+
return u8
|
|
54
|
+
else: # (T,H,W,3)
|
|
55
|
+
T, H, W, C = u8.shape
|
|
56
|
+
pad_h, pad_w = H % 2, W % 2
|
|
57
|
+
if pad_h or pad_w:
|
|
58
|
+
out = np.zeros((T, H + pad_h, W + pad_w, C), dtype=np.uint8)
|
|
59
|
+
out[:, :H, :W, :] = u8
|
|
60
|
+
return out
|
|
61
|
+
return u8
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def save_numpy_gif(
    frames: np.ndarray, out_path: str, fps: int = 10, loop: int = 0
) -> None:
    """Write a clip to `out_path` as an animated GIF via ``imageio.mimsave``.

    The clip is converted to uint8 first, then each frame along the time axis
    is written with ``duration=1 / fps``, a 256-colour palette and
    sub-rectangle optimisation enabled.

    NOTE(review): imageio's newer pillow-based GIF writer interprets
    ``duration`` in milliseconds, while the legacy GIF-PIL writer uses
    seconds. Confirm which plugin ``format="GIF"`` selects for the pinned
    imageio version.

    Args:
        frames (np.ndarray): Clip of shape (T, H, W) or (T, H, W, C) where C
            is 1 or 3.
        out_path (str): Destination file path.
        fps (int): Playback rate in frames per second; must be non-zero.
        loop (int): Number of times the GIF should loop (0 = infinite).
    """
    clip_u8, _ = _normalize_frames(frames)
    frame_list = list(clip_u8)  # one 2-D/3-D frame per timestep
    gif_options = {
        "format": "GIF",
        "duration": 1.0 / float(fps),
        "loop": loop,
        "palettesize": 256,
        "subrectangles": True,
    }
    imageio.mimsave(out_path, frame_list, **gif_options)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def save_numpy_mp4(
    frames: np.ndarray,
    out_path: str | Path,
    fps: int = 30,
    codec: str = "libx264",
    pix_fmt: str | None = "yuv420p",
    bitrate: str | None = None,
    crf: int | None = 18,
    convert_gray_to_rgb: bool = True,
    preset: str | None = "medium",
) -> None:
    """Save a NumPy clip to MP4 using imageio-ffmpeg.

    Args:
        frames (np.ndarray): Input clip as a NumPy array of shape
            (T, H, W) or (T, H, W, C) where C is 1 or 3.
        out_path (str | Path): Output file path.
        fps (int): Frames per second.
        codec (str): Video codec to use.
        pix_fmt (str | None): Output pixel format, passed to imageio as
            ``pixelformat``. If None, imageio's default is used.
        bitrate (str | None): Bitrate string for ffmpeg (e.g. "4M").
        crf (int | None): Constant Rate Factor passed as ``-crf``. When set,
            imageio's own ``quality``-derived rate flags are disabled so
            ffmpeg does not receive two conflicting ``-crf`` values.
        convert_gray_to_rgb (bool): Whether to replicate grayscale frames
            into 3 channels.
        preset (str | None): ffmpeg preset for speed/quality tradeoff.
    """
    arr_u8, mode = _normalize_frames(frames)

    # Prefer RGB for broad player compatibility
    if mode == "L" and convert_gray_to_rgb:
        arr_u8 = np.stack([arr_u8] * 3, axis=-1)  # (T,H,W) -> (T,H,W,3)

    # Ensure even dims for yuv420p
    arr_u8 = _ensure_even_hw(arr_u8)

    writer_kwargs = {
        "fps": fps,
        "codec": codec,
        # Frames are already padded to even sizes; None stops imageio from
        # rescaling them to a multiple of its default macro block size.
        "macro_block_size": None,
        "format": "FFMPEG",
    }
    if bitrate is not None:
        writer_kwargs["bitrate"] = bitrate
    if pix_fmt is not None:
        # Use the dedicated kwarg rather than a second "-pix_fmt" in
        # ffmpeg_params (imageio already emits one with its default).
        writer_kwargs["pixelformat"] = pix_fmt

    ffmpeg_params: list[str] = []
    if crf is not None:
        ffmpeg_params += ["-crf", str(crf)]
        # imageio's default quality=5 would add its own "-crf" for libx264;
        # None disables that so only the requested value reaches ffmpeg.
        writer_kwargs["quality"] = None
    if preset is not None:
        ffmpeg_params += ["-preset", preset]
    if ffmpeg_params:
        writer_kwargs["ffmpeg_params"] = ffmpeg_params

    with imageio.get_writer(out_path, **writer_kwargs) as writer:
        # Iterating the clip yields one frame per timestep, grayscale or RGB.
        for frame in arr_u8:
            writer.append_data(frame)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def save_numpy_image(image: np.ndarray, out_path: str, format: str) -> None:
    """Save a single image to file using imageio.

    Args:
        image (np.ndarray): Image of shape (H, W) or (H, W, C) where C is 1
            or 3. Any array-like accepted by ``np.asarray`` (e.g. nested
            lists) also works.
        out_path (str): Output file path.
        format (str): Image format passed to imageio (e.g. 'png', 'jpg', 'jpeg').

    Raises:
        ValueError: If `image` has fewer than 2 dimensions or an unsupported
            channel count.
    """
    # Add a leading time axis so the clip normalizer can be reused.
    # np.asarray first: plain array-likes such as lists do not support
    # `[np.newaxis, ...]` indexing.
    batch = np.expand_dims(np.asarray(image), axis=0)
    arr_u8, _ = _normalize_frames(batch)
    imageio.imwrite(out_path, arr_u8[0], format=format)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def save_numpy_vector_field_visualization(
    vector_field: np.ndarray,
    dir: Path,
    name: str,
    mode: Literal["vorticity", "magnitude"] = "magnitude",
    arrow_stride: int = 8,
    dpi: int = 300,
    add_colorbar: bool = True,
) -> None:
    """Save a 2D vector field visualization as ``<dir>/<name>.png``.

    A scalar background (speed or vorticity) is drawn with ``imshow``, and a
    subsampled quiver of the vectors is drawn on top.

    The plot is built on a ``matplotlib.figure.Figure`` directly instead of
    through pyplot. This means the global backend is never switched:
    ``matplotlib.use`` with pyplot imported closes all open figures when the
    backend changes. It also means no figure stays registered with pyplot if
    saving fails.

    Args:
        vector_field (np.ndarray): Vector field of shape (H, W, 2). Channel 0
            is the x component and channel 1 the y component; row 0 is drawn
            at the bottom (``origin="lower"``).
        dir (Path): Output directory (a ``str`` is also accepted); created if
            missing.
        name (str): Base name for the output PNG file (without extension).
        mode (Literal["vorticity", "magnitude"]): ``"magnitude"`` shows the
            vector norm with viridis. ``"vorticity"`` shows
            ``dVy/dx - dVx/dy`` with a diverging colormap centred at zero.
        arrow_stride (int): Draw one arrow every `arrow_stride` pixels (>= 1).
        dpi (int): Resolution of the output image.
        add_colorbar (bool): Whether to include a colorbar.

    Raises:
        ValueError: If `vector_field` is not (H, W, 2), `mode` is unknown, or
            `arrow_stride` < 1.
    """
    from matplotlib.figure import Figure

    field = np.asarray(vector_field)
    if field.ndim != 3 or field.shape[-1] != 2:
        raise ValueError(f"vector_field must have shape (H, W, 2), got {field.shape}")
    if arrow_stride < 1:
        raise ValueError(f"arrow_stride must be >= 1, got {arrow_stride}")
    H, W, _ = field.shape
    u = field[..., 0]
    v = field[..., 1]

    # Scalar background, colormap, arrow colour and colour limits per mode
    vmin = vmax = None
    if mode == "magnitude":
        scalar_field = np.linalg.norm(field, axis=-1)
        cmap = "viridis"
        arrow_color = "white"
    elif mode == "vorticity":
        # 2D vorticity (z-component of the curl): dVy/dx - dVx/dy.
        # Axis 0 (rows) is y and axis 1 (columns) is x.
        dvx_dy = np.gradient(u, axis=0)
        dvy_dx = np.gradient(v, axis=1)
        scalar_field = dvy_dx - dvx_dy
        cmap = "RdBu_r"
        arrow_color = "black"
        # Centre the diverging colormap on zero so colour encodes rotation sign.
        limit = float(np.max(np.abs(scalar_field)))
        if limit > 0:
            vmin, vmax = -limit, limit
    else:
        raise ValueError("mode must be 'magnitude' or 'vorticity'")

    # Figure size follows the field's pixel aspect ratio
    fig = Figure(figsize=(W / 50, H / 50), dpi=dpi)
    ax = fig.subplots()
    ax.set_aspect("equal")

    # Display scalar field as background
    im = ax.imshow(
        scalar_field,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        extent=[0, W, 0, H],
        interpolation="bilinear",
    )

    # Subsampled arrow grid; scale_units="xy" with scale=1 draws arrow
    # lengths in data (pixel) units.
    y_coords, x_coords = np.mgrid[0:H:arrow_stride, 0:W:arrow_stride]
    ax.quiver(
        x_coords,
        y_coords,
        u[::arrow_stride, ::arrow_stride],
        v[::arrow_stride, ::arrow_stride],
        color=arrow_color,
        alpha=0.9,
        scale_units="xy",
        scale=1,
        width=0.002,
        headwidth=4,
        headlength=5,
        headaxislength=4.5,
        linewidth=0.5,
        edgecolor="none",
    )

    # Remove axes and ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

    if add_colorbar:
        # Colorbar axes split off the main axes so it has the same height
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.02)
        cb = fig.colorbar(im, cax=cax)
        cb.ax.tick_params(length=2)
        # Leave a tiny margin so ticks/labels aren't clipped
        fig.subplots_adjust(left=0.0, right=0.98, bottom=0.0, top=1.0)
        pad_inches = 0.02
    else:
        # No colorbar: strip all outer margins
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        pad_inches = 0.0

    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        str(out_dir / f"{name}.png"),
        dpi=dpi,
        bbox_inches="tight",
        pad_inches=pad_inches,
        facecolor="white",
        edgecolor="none",
    )
|
goggles/shutdown.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Simple util for graceful shutdowns in Python applications."""
|
|
2
|
+
|
|
3
|
+
import signal
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GracefulShutdown:
    """A context manager for graceful shutdowns.

    While the context is active, SIGINT and SIGTERM do not interrupt the
    program. Instead they set ``stop`` to True (and log ``exit_message`` if
    one was given), so a loop can finish its current iteration and exit
    cleanly. The previously installed handlers are restored on exit.

    Example:
        >>> with GracefulShutdown(exit_message="Shutting down gracefully...") as gs:
        ...     while not gs.stop:
        ...         # Main application logic here, runs until interrupted
        ...         # by SIGINT or SIGTERM.
        ...         pass
        ...     print("Cleanup and exit.")

    Note:
        ``signal.signal`` may only be called from the main thread, so enter
        this context from the main thread.
    """

    stop = False

    def __init__(
        self,
        exit_message: Optional[str] = None,
    ):
        """Initialize the GracefulShutdown context manager.

        Args:
            exit_message (Optional[str]): Message logged (at INFO) when a
                signal is received. Nothing is logged if None or empty.
        """
        from . import get_logger

        self.logger = get_logger("goggles.shutdown")
        self.exit_message = exit_message
        # Handlers that were active before __enter__; restored by __exit__.
        self._orig_sigint = None
        self._orig_sigterm = None

    def __enter__(self):
        """Install the SIGINT/SIGTERM handlers and return ``self``."""
        # Reset so a reused instance does not start out already stopped.
        self.stop = False

        # save existing handlers
        self._orig_sigint = signal.getsignal(signal.SIGINT)
        self._orig_sigterm = signal.getsignal(signal.SIGTERM)

        def handle_signal(signum, frame):
            self.stop = True
            if self.exit_message:
                self.logger.info(self.exit_message)

        # register for both SIGINT and SIGTERM
        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Unregister the signal handlers, restoring the originals.

        ``signal.getsignal`` returns None when the previous handler was not
        installed from Python, and such a handler cannot be reinstalled. In
        that case Python's standard startup handler is used as a fallback
        (``default_int_handler`` for SIGINT, ``SIG_DFL`` for SIGTERM) rather
        than leaving this instance's handler in place, which would keep
        swallowing Ctrl-C after the block ends. Exceptions raised inside the
        block are not suppressed.

        Args:
            exc_type: Exception type if any.
            exc_value: Exception value if any.
            traceback: Traceback if any.
        """
        restore_plan = (
            (signal.SIGINT, self._orig_sigint, signal.default_int_handler),
            (signal.SIGTERM, self._orig_sigterm, signal.SIG_DFL),
        )
        for signum, original, fallback in restore_plan:
            signal.signal(signum, fallback if original is None else original)
|
goggles/types.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Types used in Goggles."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Dict, Literal, Any, Optional
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TypeAlias
|
|
7
|
+
|
|
8
|
+
# Event categories routed through the EventBus; the allowed values of `Event.kind`.
Kind: TypeAlias = Literal["log", "metric", "image", "video", "artifact"]

# Scalar metrics keyed by metric name.
Metrics: TypeAlias = Dict[str, float | int]
# The array aliases below are plain np.ndarray; shapes are not enforced here.
# NOTE(review): layouts presumably follow the goggles.media docstrings — confirm.
Image: TypeAlias = np.ndarray  # presumably (H, W) or (H, W, C), C in {1, 3}
Video: TypeAlias = np.ndarray  # presumably (T, H, W) or (T, H, W, C)
Vector: TypeAlias = np.ndarray  # layout not specified in this module
VectorField: TypeAlias = np.ndarray  # presumably (H, W, 2)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class Event:
    """Immutable record describing one event routed through the EventBus.

    Args:
        kind (Kind): Event category ("log", "metric", "image", "video", "artifact").
        scope (str): Scope of the event ("global" or "run").
        payload (Any): The event's content.
        filepath (str): Source file of the code that emitted the event.
        lineno (int): Line in `filepath` where the event was emitted.
        level (Optional[int]): Log level, for "log" events.
        step (Optional[int]): Global step index, if any.
        time (Optional[float]): Global timestamp, if any.
        extra (Optional[dict[str, Any]]): Additional metadata, if any.
    """

    kind: Kind
    scope: str
    payload: Any
    filepath: str
    lineno: int
    level: Optional[int] = None
    step: Optional[int] = None
    time: Optional[float] = None
    extra: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; optional fields equal to None are omitted."""
        as_dict: Dict[str, Any] = dict(
            kind=self.kind,
            scope=self.scope,
            payload=self.payload,
            filepath=self.filepath,
            lineno=self.lineno,
        )
        optional_fields = (
            ("level", self.level),
            ("step", self.step),
            ("time", self.time),
            ("extra", self.extra),
        )
        as_dict.update((key, value) for key, value in optional_fields if value is not None)
        return as_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Event":
        """Rebuild an Event from a dict such as one produced by `to_dict`.

        Required keys must be present (KeyError otherwise); optional keys
        default to None when absent.
        """
        required = ("kind", "scope", "payload", "filepath", "lineno")
        optional = ("level", "step", "time", "extra")
        kwargs = {key: data[key] for key in required}
        kwargs.update({key: data.get(key) for key in optional})
        return cls(**kwargs)
|