robo-goggles 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,127 @@
+ """Events routing across files, processes, and machines.
+
+ This module encapsulates the multi-machine, multi-process routing of events
+ via the EventBus class. It uses a client-server model where one process
+ acts as the host (server) and others connect to it (clients).
+
+ Example:
+     >>> bus = get_bus()
+
+ """
+
+ from __future__ import annotations
+
+ from typing import Optional
+ import portal
+ import socket
+ import netifaces
+
+ from goggles import EventBus, Event, GOGGLES_HOST, GOGGLES_PORT
+
+ # Singleton factory ---------------------------------------------------------
+ __singleton_client: Optional[portal.Client] = None
+ __singleton_server: Optional[portal.Server] = None
+
+
+ def __i_am_host() -> bool:
+     """Return whether this process is the goggles event bus host.
+
+     Returns:
+         bool: True if this process is the host, False otherwise.
+
+     """
+     # If GOGGLES_HOST is localhost/127.0.0.1, we are always the host
+     if GOGGLES_HOST in ("localhost", "127.0.0.1", "::1"):
+         return True
+
+     # Get all local IP addresses
+     hostname = socket.gethostname()
+     local_ips = set()
+
+     # Add hostname resolution
+     try:
+         local_ips.add(socket.gethostbyname(hostname))
+     except socket.gaierror:
+         pass
+
+     # Add all interface IPs
+     for interface in netifaces.interfaces():
+         addrs = netifaces.ifaddresses(interface)
+         for addr_family in [netifaces.AF_INET, netifaces.AF_INET6]:
+             if addr_family in addrs:
+                 for addr_info in addrs[addr_family]:
+                     if "addr" in addr_info:
+                         local_ips.add(addr_info["addr"])
+
+     # Check if GOGGLES_HOST matches any local IP
+     return GOGGLES_HOST in local_ips
+
+
+ def __is_port_in_use(host: str, port: int) -> bool:
+     """Check if a port is already in use.
+
+     Args:
+         host (str): The host to check.
+         port (int): The port to check.
+
+     Returns:
+         bool: True if the port is in use, False otherwise.
+
+     """
+     try:
+         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+             sock.settimeout(1)  # 1 second timeout
+             result = sock.connect_ex((host, port))
+             return result == 0  # 0 means connection successful (port in use)
+     except Exception:
+         return False
+
+
+ def get_bus() -> portal.Client:
+     """Return the process-wide EventBus singleton.
+
+     This function ensures that there is a single instance of the
+     EventBus for the entire application, even if distributed across machines.
+
+     It uses a client-server model where one process acts as the host
+     (server) and others connect to it (clients). The host is determined
+     based on the GOGGLES_HOST configuration. The methods of EventBus are
+     exposed via a portal server for remote invocation.
+
+     NOTE: It is not thread-safe. It works on multiple machines and multiple
+     processes, but it is not guaranteed to work consistently for multiple
+     threads within the same process.
+
+     Returns:
+         portal.Client: The singleton EventBus client.
+
+     """
+     if __i_am_host() and not __is_port_in_use(GOGGLES_HOST, int(GOGGLES_PORT)):
+         global __singleton_server
+         try:
+             event_bus = EventBus()
+             server = portal.Server(
+                 GOGGLES_PORT, name=f"EventBus-Server@{socket.gethostname()}"
+             )
+             server.bind("attach", event_bus.attach)
+             server.bind("detach", event_bus.detach)
+             server.bind("emit", event_bus.emit)
+             server.bind("shutdown", event_bus.shutdown)
+             server.start(block=False)
+             __singleton_server = server
+         except OSError:
+             # Fallback: server creation failed for another reason
+             # (e.g. a concurrent process already started it); nothing more to do.
+             pass
+
+     global __singleton_client
+     if __singleton_client is None:
+         __singleton_client = portal.Client(
+             f"{GOGGLES_HOST}:{GOGGLES_PORT}",
+             name=f"EventBus-Client@{socket.gethostname()}",
+         )
+
+     return __singleton_client
+
+
+ __all__ = ["Event", "EventBus", "get_bus"]
goggles/config.py ADDED
@@ -0,0 +1,68 @@
+ """Utilities for loading and pretty-printing configuration files."""
+
+ from ruamel.yaml import YAML
+ from rich.console import Console
+ from rich.pretty import Pretty
+ from yaml.representer import SafeRepresenter
+
+
+ class PrettyConfig(dict):
+     """Dictionary subclass that pretty-prints itself using rich."""
+
+     def __str__(self):
+         """Return a pretty-printed string of the configuration."""
+         console = Console()
+         plain = dict(self)
+         with console.capture() as capture:
+             console.print(Pretty(plain))
+         return capture.get()
+
+     __repr__ = __str__
+
+
+ def load_configuration(file_path: str) -> PrettyConfig:
+     """Load YAML configuration from file and return as PrettyConfig.
+
+     Args:
+         file_path (str): Path to the YAML configuration file.
+
+     Returns:
+         PrettyConfig: A PrettyConfig object containing the loaded configuration.
+
+     Raises:
+         FileNotFoundError: If the specified file does not exist.
+
+     """
+     yaml = YAML(typ="safe", pure=True)
+
+     with open(file_path, "r", encoding="utf-8") as f:
+         data = yaml.load(f) or {}
+     # Wrap the loaded dict in our PrettyConfig
+     return PrettyConfig(data)
+
+
+ def represent_prettyconfig(dumper, data):
+     """Represent PrettyConfig as a YAML mapping.
+
+     Args:
+         dumper: The YAML dumper.
+         data: The PrettyConfig instance.
+
+     """
+     return dumper.represent_mapping("tag:yaml.org,2002:map", dict(data))
+
+
+ SafeRepresenter.add_representer(PrettyConfig, represent_prettyconfig)
+
+
+ def save_configuration(config: PrettyConfig, file_path: str):
+     """Dump PrettyConfig to a YAML file.
+
+     Args:
+         config (PrettyConfig): The configuration to dump.
+         file_path (str): Path to the output YAML file.
+
+     """
+     yaml = YAML(typ="safe", pure=True)
+     with open(file_path, "w", encoding="utf-8") as f:
+         yaml.dump(dict(config), f)
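A short round-trip sketch of these helpers (file names are placeholders, not from the package):

# Round-trip sketch: load, tweak, and save a configuration.
from goggles.config import load_configuration, save_configuration

cfg = load_configuration("config.yaml")     # PrettyConfig, a plain dict underneath
cfg["logging"] = {"level": "INFO"}          # edit like any dict
save_configuration(cfg, "config_out.yaml")  # dumped via ruamel.yaml's safe dumper
print(cfg)                                  # rich pretty-printed representation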
goggles/decorators.py ADDED
@@ -0,0 +1,81 @@
+ """Decorators for logging and timing function execution."""
+
+ import logging
+
+
+ def timeit(severity=logging.INFO, name=None):
+     """Measure the execution time of a function via decorators.
+
+     Args:
+         severity (Severity): Log severity level for timing message.
+         name (str): Optional name for the timing entry.
+             If None, uses filename:function_name.
+
+     Example:
+         >>> @timeit(severity=Severity.DEBUG, name="my_function_timing")
+         ... def my_function():
+         ...     # function logic here
+         ...     pass
+         >>> my_function()
+         DEBUG: my_function_timing took 0.123456s
+
+     """
+     from goggles import GogglesLogger
+
+     def decorator(func):
+         import time
+         import os
+         from . import get_logger
+
+         logger: GogglesLogger = get_logger(
+             "goggles.decorators.timeit", with_metrics=True
+         )
+
+         def wrapper(*args, **kwargs):
+             start = time.perf_counter()
+             result = func(*args, **kwargs)
+             duration = time.perf_counter() - start
+             filename = os.path.basename(func.__code__.co_filename)
+             fname = name or f"{filename}:{func.__name__}"
+             logger.log(severity, f"{fname} took {duration:.6f}s")
+             logger.scalar(f"timings/{fname}", duration)
+             return result
+
+         return wrapper
+
+     return decorator
+
+
+ def trace_on_error():
+     """Trace errors and log function parameters via decorators.
+
+     Example:
+         >>> @trace_on_error()
+         ... def my_function(x, y):
+         ...     return x / y  # may raise ZeroDivisionError
+         >>> my_function(10, 0)
+         ERROR: Exception in my_function: division by zero; state:
+         {'args': (10, 0), 'kwargs': {}}
+
+     """
+
+     def decorator(func):
+         from . import get_logger
+
+         logger = get_logger("goggles.decorators.trace_on_error")
+
+         def wrapper(*args, **kwargs):
+             try:
+                 return func(*args, **kwargs)
+             except Exception as e:
+                 # collect parameters
+                 data = {"args": args, "kwargs": kwargs}
+                 # if method, collect self attributes
+                 if args and hasattr(args[0], "__dict__"):
+                     data["self"] = args[0].__dict__
+                 logger.error(f"Exception in {func.__name__}: {e}; state: {data}")
+                 raise
+
+         return wrapper
+
+     return decorator
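The two decorators can also be stacked; the sketch below is not from the package and the names are illustrative. With trace_on_error applied innermost it logs the original call's arguments, while timeit (outermost) records a duration only for calls that return normally.

# Stacking sketch (illustrative function and timing names).
from goggles.decorators import timeit, trace_on_error

@timeit(name="training/step")
@trace_on_error()
def training_step(batch, scale):
    return sum(batch) / scale  # a ZeroDivisionError is logged with args/kwargs, then re-raised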
@@ -0,0 +1,39 @@
+ """Device-resident temporal history buffers for JAX pipelines.
+
+ This package provides typed specifications and interfaces for constructing,
+ updating, and slicing temporal histories stored on device.
+
+ Public API:
+     - HistoryFieldSpec
+     - HistorySpec
+     - create_history
+     - update_history
+     - slice_history
+     - peek_last
+     - to_device
+     - to_host
+ """
+
+ from __future__ import annotations
+
+ try:
+     import jax  # noqa: F401
+     import jax.numpy as jnp  # noqa: F401
+ except ImportError as e:
+     raise ImportError(
+         "The 'goggles.history' module requires JAX. "
+         "Install with `pip install goggles[jax]`."
+     ) from e
+
+ from .spec import HistoryFieldSpec, HistorySpec
+ from .buffer import create_history, update_history
+ from .utils import slice_history, peek_last, to_device, to_host
+
+ __all__ = [
+     "HistoryFieldSpec",
+     "HistorySpec",
+     "create_history",
+     "update_history",
+     "slice_history",
+     "peek_last",
+     "to_device",
+     "to_host",
+ ]
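An end-to-end sketch of the public API listed above (field names and sizes are illustrative):

# Allocate a (B, T, *shape) buffer and push one timestep per batch row.
import jax.numpy as jnp
from goggles.history import HistoryFieldSpec, HistorySpec, create_history, update_history

spec = HistorySpec(fields={
    "obs": HistoryFieldSpec(length=8, shape=(3,), dtype=jnp.float32, init="zeros"),
})
history = create_history(spec, batch_size=4)      # {"obs": zeros of shape (4, 8, 3)}
step = {"obs": jnp.ones((4, 1, 3), jnp.float32)}  # new entry, shaped (B, 1, *shape)
history = update_history(history, step)           # shift left along time, append at the end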
@@ -0,0 +1,185 @@
+ """Creation and update interfaces for device-resident history buffers."""
+
+ import jax
+ import jax.numpy as jnp
+ from typing import Dict, Optional
+ from .spec import HistorySpec
+ from .types import PRNGKey, Array, History
+
+
+ def _apply_reset(
+     hist_row: Array,
+     new_row: Array,
+     reset: Array,
+     init_mode: str,
+     key: Optional[PRNGKey] = None,
+ ) -> Array:
+     """Shift and optionally reset a single history row.
+
+     This uses JAX-friendly control flow (lax.cond) so it can be jitted/vmap'd
+     without attempting Python-level boolean conversions of tracers.
+
+     Args:
+         hist_row: Array shaped (T, *shape) for a single batch row.
+         new_row: Array shaped (1, *shape), appended at the end along time.
+         reset: Boolean scalar (0-dim) indicating whether to reset this row.
+         init_mode: One of {"zeros", "ones", "randn", "none"}.
+         key: Optional PRNGKey (shape (2,)) used when `init_mode == "randn"`.
+
+     Returns:
+         Array with the same shape as `hist_row`, updated for this step.
+
+     Raises:
+         ValueError: If `init_mode` is unknown.
+
+     """
+     shifted_row = jnp.concatenate([hist_row[1:], new_row], axis=0)
+
+     if init_mode == "none":
+         return shifted_row
+
+     def do_reset(_):
+         if init_mode == "zeros":
+             return jnp.zeros_like(hist_row)
+         if init_mode == "ones":
+             return jnp.ones_like(hist_row)
+         if init_mode == "randn":
+             return jax.random.normal(key, hist_row.shape, hist_row.dtype)  # type: ignore
+         raise ValueError(f"Unknown init mode {init_mode!r}")
+
+     return jax.lax.cond(reset, do_reset, lambda _: shifted_row, operand=None)
+
+
+ def create_history(
+     spec: HistorySpec, batch_size: int, rng: Optional[PRNGKey] = None
+ ) -> History:
+     """Allocate device-resident history tensors following (B, T, *shape).
+
+     Args:
+         spec (HistorySpec): Spec describing each field.
+         batch_size (int): Batch size (B).
+         rng (Optional[PRNGKey]): Optional PRNG key for randomized initialization
+             of the buffers (e.g., for initial values or noise).
+
+     Returns:
+         History: Mapping from field name to array shaped (B, T, *shape).
+
+     Raises:
+         ValueError: If batch_size <= 0 or invalid spec values.
+
+     """
+     if batch_size <= 0:
+         raise ValueError("batch_size must be > 0")
+
+     history: History = {}
+     for name, field in spec.fields.items():
+         # Validate length
+         if field.length <= 0:
+             raise ValueError(f"Invalid history length for field '{name}'")
+         shape = (batch_size, field.length, *field.shape)
+
+         # Initialize according to policy
+         if field.init == "zeros":
+             arr = jnp.zeros(shape, field.dtype)
+         elif field.init == "ones":
+             arr = jnp.ones(shape, field.dtype)
+         elif field.init == "randn":
+             if rng is None:
+                 raise ValueError(f"Field '{name}' requires rng for randn init")
+             rng, sub = jax.random.split(rng)
+             arr = jax.random.normal(sub, shape, field.dtype)
+         elif field.init == "none":
+             arr = jnp.empty(shape, field.dtype)
+         else:
+             raise ValueError(f"Unknown init mode {field.init!r} for field '{name}'")
+         history[name] = arr
+     return history
+
+
+ def update_history(
+     history: History,
+     new_data: Dict[str, Array],
+     reset_mask: Optional[Array] = None,
+     spec: Optional[HistorySpec] = None,
+     rng: Optional[jax.Array] = None,
+ ) -> History:
+     """Shift and append new items along the temporal axis.
+
+     Note: this function can be jitted and vmapped over batch dimensions. RNG handling:
+     if `rng` is provided, it may be either a single PRNGKey or an array of per-batch
+     keys with shape (B, 2). This lets callers supply already-sharded keys for
+     multi-device/pmap scenarios.
+
+     Args:
+         history (History): Current history dict (B, T, *shape).
+         new_data (Dict[str, Array]): New entries per field, shaped (B, 1, *shape).
+         reset_mask (Optional[Array]): Optional boolean mask for resets (B,).
+         spec (Optional[HistorySpec]): Optional spec describing reset initialization.
+         rng (Optional[jax.Array]): Optional PRNG key for randomized resets.
+
+     Returns:
+         History: Updated history dict.
+
+     Raises:
+         ValueError: If shapes, dtypes, or append lengths are invalid.
+
+     """
+     updated: History = {}
+
+     for name, hist in history.items():
+         if name not in new_data:
+             raise ValueError(f"Missing new data for field '{name}'")
+         new = new_data[name]
+
+         # Validate shapes/dtypes
+         if new.ndim != hist.ndim:
+             raise ValueError(
+                 f"Dim mismatch for field '{name}': {new.shape} vs {hist.shape}"
+             )
+         if new.shape[1] != 1:
+             raise ValueError(f"Append length must be 1 for field '{name}'")
+         if new.dtype != hist.dtype:
+             raise ValueError(f"Dtype mismatch for field '{name}'")
+
+         # Determine init mode for resets
+         if spec is not None and hasattr(spec, "fields") and name in spec.fields:
+             init_mode = spec.fields[name].init
+         else:
+             init_mode = "zeros"
+
+         # Fast path: no reset handling requested.
+         if reset_mask is None:
+             updated_field = jnp.concatenate([hist[:, 1:, ...], new], axis=1)
+             updated[name] = updated_field
+             continue
+
+         # Validate reset mask shape.
+         if reset_mask.ndim != 1 or reset_mask.shape[0] != hist.shape[0]:
+             raise ValueError(
+                 f"Invalid reset_mask shape {reset_mask.shape}, expected (B,)"
+             )
+
+         # Prepare per-batch keys when needed.
+         if init_mode == "randn":
+             if rng is None:
+                 raise ValueError(f"Field '{name}' requires rng for randn reset")
+             rng_arr = jnp.asarray(rng)
+             if rng_arr.ndim == 1:  # single key (2,)
+                 keys = jax.random.split(rng_arr, hist.shape[0])
+             elif rng_arr.ndim == 2 and rng_arr.shape[0] == hist.shape[0]:
+                 keys = rng_arr
+             else:
+                 raise ValueError(
+                     "rng must be a PRNGKey (shape (2,)) or per-batch keys with shape "
+                     f"(B, 2); got {tuple(rng_arr.shape)}"
+                 )
+         else:
+             # Dummy keys; ignored unless init_mode == 'randn'.
+             keys = jnp.zeros((hist.shape[0], 2), dtype=jnp.uint32)
+
+         # Vmap over batch. Keep new with time-dim = 1 for concat in helper.
+         apply = lambda h, n, r, k: _apply_reset(h, n, r, init_mode, k)
+         updated_field = jax.vmap(apply)(hist, new[:, 0:1, ...], reset_mask, keys)
+         updated[name] = updated_field
+
+     return updated
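To illustrate the reset path and the RNG convention described in update_history (a sketch; the field name and sizes are illustrative):

# Rows flagged in reset_mask are re-initialized per the field's init policy
# ("randn" here); a single PRNGKey is split into per-batch keys internally.
import jax
import jax.numpy as jnp
from goggles.history import HistoryFieldSpec, HistorySpec, create_history, update_history

spec = HistorySpec(fields={
    "noise": HistoryFieldSpec(length=4, shape=(2,), dtype=jnp.float32, init="randn"),
})
key = jax.random.PRNGKey(0)
history = create_history(spec, batch_size=3, rng=key)
step = {"noise": jnp.zeros((3, 1, 2), jnp.float32)}
reset = jnp.array([True, False, False])  # only row 0 restarts its history
history = update_history(history, step, reset_mask=reset, spec=spec, rng=key)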
@@ -0,0 +1,143 @@
+ """Type specifications for device-resident history buffers."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, Literal, Mapping, Tuple
+
+ import jax.numpy as jnp
+
+ InitMode = Literal["zeros", "ones", "randn", "none"]
+
+
+ @dataclass(frozen=True)
+ class HistoryFieldSpec:
+     """Describe one temporal field stored on device.
+
+     Attributes:
+         length (int): Number of stored timesteps for this field.
+         shape (Tuple[int, ...]): Per-timestep payload shape (no batch/time dims).
+         dtype (jnp.dtype): Array dtype.
+         init (InitMode): Initialization policy ("zeros" | "ones" | "randn" | "none").
+
+     """
+
+     length: int
+     shape: tuple[int, ...]
+     dtype: jnp.dtype = jnp.float32
+     init: InitMode = "zeros"
+
+
+ @dataclass(frozen=True)
+ class HistorySpec:
+     """Bundle multiple named history field specs.
+
+     Attributes:
+         fields (Mapping[str, HistoryFieldSpec]): Mapping from field name to spec
+
+     """
+
+     fields: Mapping[str, HistoryFieldSpec]
+
+     @classmethod
+     def from_config(cls, config: Mapping[str, Any]) -> "HistorySpec":
+         """Construct a HistorySpec from a nested config dictionary.
+
+         Args:
+             config (Mapping[str, Any]): Dict mapping field name to kwargs for
+                 `HistoryFieldSpec` or to an already-built `HistoryFieldSpec`. Each
+                 kwargs dict must include:
+                 - "length" (int): Number of timesteps (T >= 1).
+                 - "shape" (Sequence[int] | tuple[int, ...]): Per-timestep shape.
+                 Optional keys:
+                 - "dtype": Anything accepted by `jnp.dtype` (default float32).
+                 - "init": One of {"zeros", "ones", "randn", "none"}.
+
+         Returns:
+             HistorySpec: Parsed specification bundle.
+
+         Raises:
+             TypeError: If `config` is not a mapping, or a field entry has
+                 unsupported type, or shapes/dtypes have invalid types.
+             ValueError: If required keys are missing or values are invalid
+                 (e.g., length < 1, negative dims, unknown init mode).
+
+         """
+         if not isinstance(config, Mapping):
+             raise TypeError("config must be a Mapping[str, Any].")
+
+         allowed_inits: Tuple[str, ...] = ("zeros", "ones", "randn", "none")
+         out: Dict[str, HistoryFieldSpec] = {}
+
+         for name, spec in config.items():
+             if isinstance(spec, HistoryFieldSpec):
+                 # Validate basic invariants even if user provided an instance.
+                 if not isinstance(spec.length, int) or spec.length < 1:
+                     raise ValueError(
+                         f"{name!r}.length must be an int >= 1, got {spec.length}."
+                     )
+                 if any((not isinstance(d, int) or d < 0) for d in spec.shape):
+                     raise ValueError(
+                         f"{name!r}.shape must be a tuple of non-negative ints, "
+                         f"got {spec.shape}."
+                     )
+                 if spec.init not in allowed_inits:
+                     raise ValueError(
+                         f"{name!r}.init must be one of {allowed_inits}, got {spec.init}."
+                     )
+                 out[name] = spec
+                 continue
+
+             if not isinstance(spec, Mapping):
+                 raise TypeError(
+                     f"Field {name!r} must be a Mapping or HistoryFieldSpec, "
+                     f"got {type(spec).__name__}."
+                 )
+
+             # Required keys
+             if "length" not in spec or "shape" not in spec:
+                 raise ValueError(
+                     f"Field {name!r} must define 'length' and 'shape'. Got keys: "
+                     f"{list(spec.keys())}"
+                 )
+
+             # Validate length
+             length = spec["length"]
+             if not isinstance(length, int) or length < 1:
+                 raise ValueError(f"{name!r}.length must be an int >= 1, got {length}.")
+
+             # Validate shape
+             shape_val = spec["shape"]
+             if not isinstance(shape_val, (tuple, list)):
+                 raise TypeError(
+                     f"{name!r}.shape must be a tuple/list of ints, "
+                     f"got {type(shape_val).__name__}."
+                 )
+             shape_tuple = tuple(int(d) for d in shape_val)
+             if any(d < 0 for d in shape_tuple):
+                 raise ValueError(
+                     f"{name!r}.shape must contain non-negative ints, got {shape_tuple}."
+                 )
+
+             # Optional keys: dtype/init
+
+             # Validate dtype
+             try:
+                 dtype = jnp.dtype(spec.get("dtype", jnp.float32))
+             except Exception as e:
+                 raise TypeError(
+                     f"{name!r}.dtype is not a valid JAX dtype: {spec.get('dtype')!r}."
+                 ) from e
+
+             # Validate init
+             init = spec.get("init", "zeros")
+             if init not in allowed_inits:
+                 raise ValueError(
+                     f"{name!r}.init must be one of {allowed_inits}, got {init!r}."
+                 )
+
+             out[name] = HistoryFieldSpec(
+                 length=length, shape=shape_tuple, dtype=dtype, init=init
+             )
+
+         return cls(fields=out)
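A minimal from_config sketch (field names and values are illustrative):

# Plain dicts are validated and normalized into HistoryFieldSpec instances.
import jax.numpy as jnp
from goggles.history import HistorySpec

spec = HistorySpec.from_config({
    "obs": {"length": 8, "shape": (3,)},  # defaults: float32, "zeros"
    "action": {"length": 8, "shape": [2], "dtype": "int32", "init": "none"},
})
assert spec.fields["action"].dtype == jnp.dtype("int32")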
@@ -0,0 +1,9 @@
+ """Shared type aliases for history package."""
+
+ from __future__ import annotations
+ from typing import Dict
+ import jax.numpy as jnp
+
+ PRNGKey = jnp.ndarray
+ Array = jnp.ndarray
+ History = Dict[str, Array]