robo-goggles 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goggles/__init__.py +786 -0
- goggles/_core/integrations/__init__.py +26 -0
- goggles/_core/integrations/console.py +111 -0
- goggles/_core/integrations/storage.py +382 -0
- goggles/_core/integrations/wandb.py +253 -0
- goggles/_core/logger.py +602 -0
- goggles/_core/routing.py +127 -0
- goggles/config.py +68 -0
- goggles/decorators.py +81 -0
- goggles/history/__init__.py +39 -0
- goggles/history/buffer.py +185 -0
- goggles/history/spec.py +143 -0
- goggles/history/types.py +9 -0
- goggles/history/utils.py +191 -0
- goggles/media.py +284 -0
- goggles/shutdown.py +70 -0
- goggles/types.py +79 -0
- robo_goggles-0.1.0.dist-info/METADATA +600 -0
- robo_goggles-0.1.0.dist-info/RECORD +22 -0
- robo_goggles-0.1.0.dist-info/WHEEL +5 -0
- robo_goggles-0.1.0.dist-info/licenses/LICENSE +21 -0
- robo_goggles-0.1.0.dist-info/top_level.txt +1 -0
goggles/_core/routing.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Route events across files, processes, and machines.
|
|
2
|
+
|
|
3
|
+
This module encapsulates the multi-machine, multi-process routing of events
|
|
4
|
+
via the EventBus class. It uses a client-server model where one process
|
|
5
|
+
acts as the host (server) and others connect to it (clients).
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> bus = get_bus()
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Optional
|
|
15
|
+
import portal
|
|
16
|
+
import socket
|
|
17
|
+
import netifaces
|
|
18
|
+
|
|
19
|
+
from goggles import EventBus, Event, GOGGLES_HOST, GOGGLES_PORT
|
|
20
|
+
|
|
21
|
+
# Process-wide singletons, created lazily by get_bus() -----------------------
# Server side: only set in the process that successfully starts hosting the bus.
__singleton_server: Optional[portal.Server] = None
# Client side: every process (the host included) talks to the bus through this.
__singleton_client: Optional[portal.Client] = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def __i_am_host() -> bool:
    """Return whether this process is the goggles event bus host.

    This machine is treated as the host when ``GOGGLES_HOST`` is a loopback
    name/address, equals this machine's hostname, equals the address that
    hostname resolves to, or equals an address bound to a local interface.

    Returns:
        bool: True if this process is the host, False otherwise.

    """
    # If GOGGLES_HOST is localhost/127.0.0.1/::1, we are always the host.
    if GOGGLES_HOST in ("localhost", "127.0.0.1", "::1"):
        return True

    hostname = socket.gethostname()
    # A GOGGLES_HOST given as this machine's hostname also designates us.
    local_names = {hostname}

    # Add the address the hostname resolves to. This is best effort:
    # resolution can fail on machines without a matching hosts/DNS entry.
    try:
        local_names.add(socket.gethostbyname(hostname))
    except OSError:
        pass

    # Add every address bound to a local network interface.
    for interface in netifaces.interfaces():
        addrs = netifaces.ifaddresses(interface)
        for addr_family in (netifaces.AF_INET, netifaces.AF_INET6):
            for addr_info in addrs.get(addr_family, ()):
                addr = addr_info.get("addr")
                if not addr:
                    continue
                local_names.add(addr)
                # IPv6 link-local addresses may carry a "%<zone>" suffix
                # (e.g. "fe80::1%eth0"); also record the bare address so a
                # GOGGLES_HOST written without the zone still matches.
                local_names.add(addr.split("%", 1)[0])

    # We are the host iff GOGGLES_HOST names one of our own endpoints.
    return GOGGLES_HOST in local_names
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def __is_port_in_use(host: str, port: int) -> bool:
    """Check whether something already accepts TCP connections on host:port.

    Args:
        host (str): The host to check (hostname, IPv4 or IPv6 address).
        port (int): The port to check.

    Returns:
        bool: True if a connection could be established (port in use),
            False otherwise, including when the host cannot be resolved or
            the port is invalid.

    """
    try:
        # create_connection resolves ``host`` and tries every address family
        # returned (IPv4 and IPv6), so hosts such as "::1" are probed
        # correctly instead of failing on a hard-coded AF_INET socket.
        with socket.create_connection((host, port), timeout=1):
            return True
    except (OSError, OverflowError, ValueError):
        # Refused, timed out, unresolvable or out-of-range port: treat as free.
        return False
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_bus() -> portal.Client:
    """Return the process-wide EventBus client singleton.

    This function ensures that there is a single instance of the
    EventBus for the entire application, even if distributed across machines.

    It uses a client-server model where one process acts as the host
    (server) and others connect to it (clients). The host is determined
    based on the GOGGLES_HOST configuration. The methods of EventBus are
    exposed via a portal server for remote invocation.

    NOTE: It is not thread-safe. It works on multiple machines and multiple
    processes, but it is not guaranteed to work consistently for multiple
    threads within the same process.

    Returns:
        portal.Client: The singleton EventBus client.

    """
    global __singleton_client, __singleton_server

    # Only try to become the server if this process is not already serving.
    # Once our own server is listening, re-probing would open a throwaway TCP
    # connection to it (with up to a 1 s timeout) on every call.
    if (
        __singleton_server is None
        and __i_am_host()
        and not __is_port_in_use(GOGGLES_HOST, int(GOGGLES_PORT))
    ):
        try:
            event_bus = EventBus()
            server = portal.Server(
                GOGGLES_PORT, name=f"EventBus-Server@{socket.gethostname()}"
            )
            server.bind("attach", event_bus.attach)
            server.bind("detach", event_bus.detach)
            server.bind("emit", event_bus.emit)
            server.bind("shutdown", event_bus.shutdown)
            server.start(block=False)
            __singleton_server = server
        except OSError:
            # Best effort: if creating/binding the server fails (e.g. another
            # process won the race for the port), fall through and connect to
            # the existing bus as a client.
            pass

    if __singleton_client is None:
        __singleton_client = portal.Client(
            f"{GOGGLES_HOST}:{GOGGLES_PORT}",
            name=f"EventBus-Client@{socket.gethostname()}",
        )

    return __singleton_client
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Public names of this module. Only names actually bound here may be listed:
# ``from goggles._core.routing import *`` raises AttributeError otherwise.
__all__ = ["Event", "EventBus", "get_bus"]
|
goggles/config.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Utilities for loading and pretty-printing configuration files."""
|
|
2
|
+
|
|
3
|
+
from ruamel.yaml import YAML
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
from rich.pretty import Pretty
|
|
6
|
+
from yaml.representer import SafeRepresenter
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PrettyConfig(dict):
    """A plain ``dict`` whose ``str()``/``repr()`` is pretty-printed by rich.

    All mapping behaviour is inherited unchanged from ``dict``; only the
    textual representation is customized.
    """

    def __str__(self):
        """Render a plain-dict copy of this mapping with rich and return the text."""
        snapshot = dict(self)
        renderer = Console()
        with renderer.capture() as captured:
            renderer.print(Pretty(snapshot))
        rendered = captured.get()
        return rendered

    # repr() shares the pretty rendering so configs display nicely in REPLs.
    __repr__ = __str__
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_configuration(file_path: str) -> PrettyConfig:
    """Load YAML configuration from file and return as PrettyConfig.

    The file is parsed with ruamel.yaml's safe, pure-Python loader. An empty
    (or otherwise falsy) document yields an empty configuration.

    Args:
        file_path (str): Path to the YAML configuration file.

    Returns:
        PrettyConfig: A PrettyConfig object containing the loaded configuration.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        TypeError: If the top level of the YAML document is not a mapping
            (e.g. a list or scalar).

    """
    from collections.abc import Mapping

    yaml = YAML(typ="safe", pure=True)

    with open(file_path, "r", encoding="utf-8") as f:
        data = yaml.load(f) or {}

    # Without this check, dict() would either fail with a confusing message
    # or silently mis-convert sequences (e.g. ["ab"] -> {"a": "b"}).
    if not isinstance(data, Mapping):
        raise TypeError(
            f"Configuration file {file_path!r} must contain a YAML mapping at "
            f"the top level, got {type(data).__name__}."
        )

    # Wrap the loaded dict in our PrettyConfig
    return PrettyConfig(data)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def represent_prettyconfig(dumper, data):
    """Serialize a PrettyConfig as an ordinary YAML mapping node.

    Args:
        dumper: YAML dumper/representer whose ``represent_mapping`` builds
            the node.
        data: The PrettyConfig instance (any mapping works).

    Returns:
        Whatever ``dumper.represent_mapping`` returns for the standard
        ``map`` tag applied to a plain-dict copy of ``data``.

    """
    map_tag = "tag:yaml.org,2002:map"
    as_plain_dict = dict(data)
    return dumper.represent_mapping(map_tag, as_plain_dict)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Register the representer with PyYAML's SafeRepresenter so PyYAML safe dumpers
# emit PrettyConfig as a plain mapping.
# NOTE(review): this hooks PyYAML (``yaml.representer``), not ruamel.yaml, so it
# does not influence save_configuration, which dumps with ruamel.yaml.
SafeRepresenter.add_representer(PrettyConfig, represent_prettyconfig)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _to_plain_data(value):
    """Return ``value`` with every nested mapping converted to a plain ``dict``.

    Lists and tuples are walked recursively, keeping their container type,
    so mappings nested inside sequences are converted too. All other values
    are returned unchanged.
    """
    from collections.abc import Mapping

    if isinstance(value, Mapping):
        return {key: _to_plain_data(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_plain_data(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_plain_data(item) for item in value)
    return value


def save_configuration(config: PrettyConfig, file_path: str):
    """Dump PrettyConfig to a YAML file.

    The configuration is converted recursively to built-in containers before
    dumping. ``dict(config)`` alone only converts the top level, so a
    PrettyConfig (or other dict subclass) nested inside would otherwise be
    handed to ruamel.yaml's safe dumper as-is. NOTE(review): safe dumpers
    typically only know the built-in container types; confirm if relaxing.

    Args:
        config (PrettyConfig): The configuration to dump.
        file_path (str): Path to the output YAML file.

    """
    yaml = YAML(typ="safe", pure=True)
    with open(file_path, "w", encoding="utf-8") as f:
        yaml.dump(_to_plain_data(dict(config)), f)
|
goggles/decorators.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Decorators for logging and timing function execution."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def timeit(severity=logging.INFO, name=None):
    """Measure the execution time of a function via decorators.

    Each call of the decorated function logs its duration (measured with
    ``time.perf_counter``) at ``severity`` and records it as the scalar
    metric ``timings/<name>``. Nothing is logged if the function raises.

    Args:
        severity (int): Log severity level for the timing message, e.g.
            ``logging.DEBUG``.
        name (str): Optional name for the timing entry.
            If None (or empty), uses filename:function_name.

    Example:
        >>> @timeit(severity=logging.DEBUG, name="my_function_timing")
        ... def my_function():
        ...     # function logic here
        ...     pass
        >>> my_function()
        DEBUG: my_function_timing took 0.123456s

    """
    from goggles import GogglesLogger

    def decorator(func):
        import functools
        import os
        import time

        from . import get_logger

        logger: GogglesLogger = get_logger(
            "goggles.decorators.timeit", with_metrics=True
        )

        # The label is fixed for the lifetime of the wrapper, so build it
        # once here rather than on every call. func.__code__ is only needed
        # when no explicit name is given; callables without it (e.g.
        # functools.partial) fall back to a placeholder filename.
        if name:
            label = name
        else:
            code = getattr(func, "__code__", None)
            filename = (
                os.path.basename(code.co_filename) if code is not None else "<unknown>"
            )
            label = f"{filename}:{getattr(func, '__name__', repr(func))}"

        # functools.wraps preserves __name__, __doc__ and __wrapped__.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            duration = time.perf_counter() - start
            logger.log(severity, f"{label} took {duration:.6f}s")
            logger.scalar(f"timings/{label}", duration)
            return result

        return wrapper

    return decorator
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def trace_on_error():
    """Trace errors and log function parameters via decorators.

    When the decorated function raises an ``Exception``, an error is logged
    containing the exception message, the positional and keyword arguments,
    and, when the first positional argument has a ``__dict__``, that
    object's attributes. The exception is then re-raised unchanged.

    Example:
        >>> @trace_on_error()
        ... def my_function(x, y):
        ...     return x / y  # may raise ZeroDivisionError
        >>> my_function(10, 0)
        ERROR: Exception in my_function: division by zero; state:
        {'args': (10, 0), 'kwargs': {}}

    """

    def decorator(func):
        import functools

        from . import get_logger

        logger = get_logger("goggles.decorators.trace_on_error")

        # functools.wraps preserves __name__, __doc__ and __wrapped__.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # collect parameters
                data = {"args": args, "kwargs": kwargs}
                # Heuristic: a first positional argument with a __dict__ is
                # assumed to be ``self``. NOTE(review): this also fires for
                # plain functions whose first argument is an ordinary object.
                if args and hasattr(args[0], "__dict__"):
                    data["self"] = args[0].__dict__
                logger.error(f"Exception in {func.__name__}: {e}; state: {data}")
                raise

        return wrapper

    return decorator
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Device-resident temporal history buffers for JAX pipelines.
|
|
2
|
+
|
|
3
|
+
This package provides typed specifications and interfaces for constructing,
|
|
4
|
+
updating, and slicing temporal histories stored on device.
|
|
5
|
+
|
|
6
|
+
Public API:
|
|
7
|
+
- HistoryFieldSpec
|
|
8
|
+
- HistorySpec
|
|
9
|
+
- create_history
|
|
10
|
+
- update_history
|
|
11
|
+
- slice_history
|
|
12
|
+
- peek_last
- to_device
- to_host
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
# goggles.history is built on JAX; fail early with an actionable message when
# the optional dependency is missing. The package is distributed on PyPI as
# "robo-goggles", so the hint must name that distribution.
try:
    import jax  # noqa: F401
    import jax.numpy as jnp  # noqa: F401
except ImportError as e:
    raise ImportError(
        "The 'goggles.history' module requires JAX. "
        "Install with `pip install robo-goggles[jax]`."
    ) from e
|
|
25
|
+
|
|
26
|
+
from .spec import HistoryFieldSpec, HistorySpec
|
|
27
|
+
from .buffer import create_history, update_history
|
|
28
|
+
from .utils import slice_history, peek_last, to_device, to_host
|
|
29
|
+
|
|
30
|
+
# Public API of goggles.history, grouped by the submodule providing each name.
__all__ = [
    # .spec: field and bundle specifications
    "HistoryFieldSpec",
    "HistorySpec",
    # .buffer: allocation and per-step updates
    "create_history",
    "update_history",
    # .utils: slicing and helpers
    "slice_history",
    "peek_last",
    "to_device",
    "to_host",
]
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""Creation and update interfaces for device-resident history buffers."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from typing import Dict, Optional
|
|
6
|
+
from .spec import HistorySpec
|
|
7
|
+
from .types import PRNGKey, Array, History
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _apply_reset(
    hist_row: Array,
    new_row: Array,
    reset: Array,
    init_mode: str,
    key: Optional[PRNGKey] = None,
) -> Array:
    """Advance one batch row of a history buffer, re-initializing it when flagged.

    The row is rolled forward by one step (oldest entry dropped, ``new_row``
    appended). When ``reset`` is true, the whole row is instead replaced by a
    freshly initialized buffer. The branch is chosen with ``jax.lax.cond``,
    so no Python-level boolean conversion of tracers happens and the helper
    can be jitted/vmapped.

    Args:
        hist_row: Array shaped (T, *shape) for a single batch row.
        new_row: Array shaped (1, *shape), appended at the end along time.
        reset: Boolean scalar (0-dim) indicating whether to reset this row.
        init_mode: One of {"zeros", "ones", "randn", "none"}; "none" means
            the row is never re-initialized.
        key: Optional PRNGKey (shape (2,)) used when `init_mode == "randn"`.

    Returns:
        Array with the same shape as `hist_row`, updated for this step.

    Raises:
        ValueError: If `init_mode` is unknown.

    """
    rolled = jnp.concatenate([hist_row[1:], new_row], axis=0)

    # "none": resets are disabled, so no conditional is needed.
    if init_mode == "none":
        return rolled

    # Builders for the fresh contents of a reset row, keyed by init mode.
    fresh_builders = {
        "zeros": lambda: jnp.zeros_like(hist_row),
        "ones": lambda: jnp.ones_like(hist_row),
        "randn": lambda: jax.random.normal(key, hist_row.shape, hist_row.dtype),  # type: ignore
    }

    def _reset_branch(_):
        build = fresh_builders.get(init_mode)
        if build is None:
            raise ValueError(f"Unknown init mode {init_mode!r}")
        return build()

    def _keep_branch(_):
        return rolled

    # NOTE(review): on reset, ``new_row`` is discarded together with the old
    # contents (the fresh buffer does not include it) — confirm intended.
    return jax.lax.cond(reset, _reset_branch, _keep_branch, operand=None)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def create_history(
    spec: HistorySpec, batch_size: int, rng: Optional[PRNGKey] = None
) -> History:
    """Build freshly initialized history buffers laid out as (B, T, *shape).

    Args:
        spec (HistorySpec): Per-field length, payload shape, dtype and init
            policy.
        batch_size (int): Leading batch dimension B shared by every field.
        rng (Optional[PRNGKey]): Key consumed (split once per "randn" field)
            to fill randomly initialized buffers.

    Returns:
        dict (History): Field name -> array shaped (B, T, *shape).

    Raises:
        ValueError: If batch_size <= 0, a field length is <= 0, a "randn"
            field is requested without `rng`, or an init mode is unknown.

    """
    if batch_size <= 0:
        raise ValueError("batch_size must be > 0")

    key = rng
    buffers: History = {}
    for field_name, field_spec in spec.fields.items():
        if field_spec.length <= 0:
            raise ValueError(f"Invalid history length for field '{field_name}'")

        full_shape = (batch_size, field_spec.length) + tuple(field_spec.shape)
        mode = field_spec.init

        if mode == "randn":
            if key is None:
                raise ValueError(f"Field '{field_name}' requires rng for randn init")
            # Split so every random field draws from an independent subkey.
            key, subkey = jax.random.split(key)
            buffer = jax.random.normal(subkey, full_shape, field_spec.dtype)
        elif mode == "zeros":
            buffer = jnp.zeros(full_shape, field_spec.dtype)
        elif mode == "ones":
            buffer = jnp.ones(full_shape, field_spec.dtype)
        elif mode == "none":
            buffer = jnp.empty(full_shape, field_spec.dtype)
        else:
            raise ValueError(
                f"Unknown init mode {mode!r} for field '{field_name}'"
            )

        buffers[field_name] = buffer

    return buffers
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def update_history(
    history: History,
    new_data: Dict[str, Array],
    reset_mask: Optional[Array] = None,
    spec: Optional[HistorySpec] = None,
    rng: Optional[jax.Array] = None,
) -> History:
    """Shift and append new items along the temporal axis.

    Note: this function can be jitted and vmapped over batch dimensions. RNG handling:
    if `rng` is provided, it may be either a single PRNGKey or an array of per-batch
    keys with shape (B, 2). This lets callers supply already-sharded keys for
    multi-device/pmap scenarios. When several fields use "randn" resets, the
    first such field consumes the supplied keys unchanged and every later
    field folds its ordinal into them, so different fields do not receive
    identical noise.

    Args:
        history (History): Current history dict (B, T, *shape).
        new_data (Dict[str, Array]): New entries per field, shaped (B, 1, *shape).
        reset_mask (Optional[Array]): Optional boolean mask for resets (B,).
        spec (Optional[HistorySpec]): Optional spec describing reset initialization
            (fields missing from the spec reset to zeros).
        rng (Optional[jax.Array]): Optional PRNG key for randomized resets.

    Returns:
        History: Updated history dict.

    Raises:
        ValueError: If a field is missing from `new_data`, if shapes, dtypes,
            or append lengths are invalid, if `reset_mask` is not shaped (B,),
            or if a "randn" reset lacks a usable `rng`.

    """
    updated: History = {}
    # Number of "randn"-reset fields processed so far (used to decorrelate keys).
    randn_fields_seen = 0

    for name, hist in history.items():
        if name not in new_data:
            raise ValueError(f"Missing new data for field '{name}'")
        new = new_data[name]

        # Validate shapes/dtypes
        if new.ndim != hist.ndim:
            raise ValueError(
                f"Dim mismatch for field '{name}': {new.shape} vs {hist.shape}"
            )
        if new.shape[1] != 1:
            raise ValueError(f"Append length must be 1 for field '{name}'")
        # Batch and payload dims must match the stored buffer; checking here
        # gives a clear error instead of a failure inside concatenate/vmap.
        if new.shape[0] != hist.shape[0] or new.shape[2:] != hist.shape[2:]:
            raise ValueError(
                f"Shape mismatch for field '{name}': {new.shape} vs {hist.shape}"
            )
        if new.dtype != hist.dtype:
            raise ValueError(f"Dtype mismatch for field '{name}'")

        # Determine init mode for resets
        if spec is not None and hasattr(spec, "fields") and name in spec.fields:
            init_mode = spec.fields[name].init
        else:
            init_mode = "zeros"

        # Fast path: no reset handling requested.
        if reset_mask is None:
            updated[name] = jnp.concatenate([hist[:, 1:, ...], new], axis=1)
            continue

        # Validate reset mask shape.
        if reset_mask.ndim != 1 or reset_mask.shape[0] != hist.shape[0]:
            raise ValueError(
                f"Invalid reset_mask shape {reset_mask.shape}, expected (B,)"
            )

        # Prepare per-batch keys when needed.
        if init_mode == "randn":
            if rng is None:
                raise ValueError(f"Field '{name}' requires rng for randn reset")
            rng_arr = jnp.asarray(rng)
            if rng_arr.ndim == 1:  # single key (2,)
                keys = jax.random.split(rng_arr, hist.shape[0])
            elif rng_arr.ndim == 2 and rng_arr.shape[0] == hist.shape[0]:
                keys = rng_arr
            else:
                raise ValueError(
                    "rng must be a PRNGKey (shape (2,)) or per-batch keys with shape "
                    f"(B, 2); got {tuple(rng_arr.shape)}"
                )
            # Reusing the same keys for every randn field would give identical
            # noise across fields; fold the field ordinal into later ones.
            if randn_fields_seen:
                keys = jax.vmap(jax.random.fold_in, in_axes=(0, None))(
                    keys, randn_fields_seen
                )
            randn_fields_seen += 1
        else:
            # Dummy keys; ignored unless init_mode == 'randn'.
            keys = jnp.zeros((hist.shape[0], 2), dtype=jnp.uint32)

        # Vmap over batch. Keep new with time-dim = 1 for concat in helper.
        apply = lambda h, n, r, k: _apply_reset(h, n, r, init_mode, k)
        updated[name] = jax.vmap(apply)(hist, new[:, 0:1, ...], reset_mask, keys)

    return updated
|
goggles/history/spec.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Type specifications for device-resident history buffers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, Literal, Mapping, Tuple
|
|
7
|
+
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
|
|
10
|
+
InitMode = Literal["zeros", "ones", "randn", "none"]


@dataclass(frozen=True)
class HistoryFieldSpec:
    """Immutable description of one temporal field stored on device.

    Attributes:
        length (int): Number of stored timesteps for this field.
        shape (Tuple[int, ...]): Per-timestep payload shape (no batch/time dims).
        dtype (jnp.dtype): Array dtype (float32 unless overridden).
        init (InitMode): Initialization policy ("zeros" | "ones" | "randn" | "none").

    """

    # Number of stored timesteps (T).
    length: int
    # Payload shape of a single timestep, excluding batch and time axes.
    shape: tuple[int, ...]
    # Storage dtype of the buffer.
    dtype: jnp.dtype = jnp.float32
    # How the buffer is filled on creation/reset.
    init: InitMode = "zeros"


@dataclass(frozen=True)
class HistorySpec:
    """Bundle multiple named history field specs.

    Attributes:
        fields (Mapping[str, HistoryFieldSpec]): Mapping from field name to spec.

    """

    fields: Mapping[str, HistoryFieldSpec]

    @classmethod
    def from_config(cls, config: Mapping[str, Any]) -> "HistorySpec":
        """Construct a HistorySpec from a nested config dictionary.

        Args:
            config (Mapping[str, Any]): Dict mapping field name to kwargs for
                `HistoryFieldSpec` or to an already-built `HistoryFieldSpec`. Each
                kwargs dict must include:
                  - "length" (int): Number of timesteps (T >= 1).
                  - "shape" (Sequence[int] | tuple[int, ...]): Per-timestep shape.
                Optional keys:
                  - "dtype": Anything accepted by `jnp.dtype` (default float32).
                  - "init": One of {"zeros", "ones", "randn", "none"}.

        Returns:
            HistorySpec: Parsed specification bundle.

        Raises:
            TypeError: If `config` is not a mapping, or a field entry has
                unsupported type, or a shape is not a list/tuple of integers
                (floats and strings are rejected, not coerced), or a dtype
                is invalid.
            ValueError: If required keys are missing or values are invalid
                (e.g., length < 1 or boolean, negative dims, unknown init mode).

        """
        import operator

        if not isinstance(config, Mapping):
            raise TypeError("config must be a Mapping[str, Any].")

        allowed_inits: Tuple[str, ...] = ("zeros", "ones", "randn", "none")
        out: Dict[str, HistoryFieldSpec] = {}

        for name, spec in config.items():
            if isinstance(spec, HistoryFieldSpec):
                # Validate basic invariants even if user provided an instance.
                # bool is an int subclass, so exclude it explicitly.
                if (
                    isinstance(spec.length, bool)
                    or not isinstance(spec.length, int)
                    or spec.length < 1
                ):
                    raise ValueError(
                        f"{name!r}.length must be an int >= 1, got {spec.length}."
                    )
                if any((not isinstance(d, int) or d < 0) for d in spec.shape):
                    raise ValueError(
                        f"{name!r}.shape must be a tuple of non-negative ints, "
                        f"got {spec.shape}."
                    )
                if spec.init not in allowed_inits:
                    raise ValueError(
                        f"{name!r}.init must be one of {allowed_inits}, got {spec.init}."
                    )
                out[name] = spec
                continue

            if not isinstance(spec, Mapping):
                raise TypeError(
                    f"Field {name!r} must be a Mapping or HistoryFieldSpec, "
                    f"got {type(spec).__name__}."
                )

            # Required keys
            if "length" not in spec or "shape" not in spec:
                raise ValueError(
                    f"Field {name!r} must define 'length' and 'shape'. Got keys: "
                    f"{list(spec.keys())}"
                )

            # Validate length (bool is an int subclass; reject `length: true`).
            length = spec["length"]
            if isinstance(length, bool) or not isinstance(length, int) or length < 1:
                raise ValueError(f"{name!r}.length must be an int >= 1, got {length}.")

            # Validate shape
            shape_val = spec["shape"]
            if not isinstance(shape_val, (tuple, list)):
                raise TypeError(
                    f"{name!r}.shape must be a tuple/list of ints, "
                    f"got {type(shape_val).__name__}."
                )
            # operator.index accepts genuine integers (including NumPy integer
            # scalars) but raises TypeError for floats and strings, which int()
            # would silently truncate (2.7 -> 2) or parse ("3" -> 3).
            try:
                shape_tuple = tuple(operator.index(d) for d in shape_val)
            except TypeError as e:
                raise TypeError(
                    f"{name!r}.shape must contain only ints, got {list(shape_val)!r}."
                ) from e
            if any(d < 0 for d in shape_tuple):
                raise ValueError(
                    f"{name!r}.shape must contain non-negative ints, got {shape_tuple}."
                )

            # Optional keys: dtype/init

            # Validate dtype
            try:
                dtype = jnp.dtype(spec.get("dtype", jnp.float32))
            except Exception as e:
                raise TypeError(
                    f"{name!r}.dtype is not a valid JAX dtype: {spec.get('dtype')!r}."
                ) from e

            # Validate init
            init = spec.get("init", "zeros")
            if init not in allowed_inits:
                raise ValueError(
                    f"{name!r}.init must be one of {allowed_inits}, got {init!r}."
                )

            out[name] = HistoryFieldSpec(
                length=length, shape=shape_tuple, dtype=dtype, init=init
            )

        return cls(fields=out)
|