dynestyx 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dynestyx/__init__.py +49 -0
- dynestyx/diagnostics/__init__.py +1 -0
- dynestyx/diagnostics/plotting_utils.py +434 -0
- dynestyx/discretizers.py +206 -0
- dynestyx/handlers.py +155 -0
- dynestyx/inference/__init__.py +7 -0
- dynestyx/inference/filter_configs.py +675 -0
- dynestyx/inference/filters.py +352 -0
- dynestyx/inference/hmm_filters.py +241 -0
- dynestyx/inference/integrations/__init__.py +3 -0
- dynestyx/inference/integrations/blackjax/__init__.py +5 -0
- dynestyx/inference/integrations/blackjax/mcmc.py +319 -0
- dynestyx/inference/integrations/cd_dynamax/__init__.py +1 -0
- dynestyx/inference/integrations/cd_dynamax/continuous.py +266 -0
- dynestyx/inference/integrations/cd_dynamax/discrete.py +316 -0
- dynestyx/inference/integrations/cd_dynamax/utils.py +440 -0
- dynestyx/inference/integrations/cuthbert/__init__.py +1 -0
- dynestyx/inference/integrations/cuthbert/discrete.py +458 -0
- dynestyx/inference/integrations/utils.py +83 -0
- dynestyx/inference/mcmc.py +158 -0
- dynestyx/inference/mcmc_configs.py +113 -0
- dynestyx/models/__init__.py +39 -0
- dynestyx/models/checkers.py +163 -0
- dynestyx/models/core.py +376 -0
- dynestyx/models/lti_dynamics.py +189 -0
- dynestyx/models/observations.py +112 -0
- dynestyx/models/state_evolution.py +147 -0
- dynestyx/simulators.py +1044 -0
- dynestyx/types.py +14 -0
- dynestyx/utils.py +223 -0
- dynestyx-0.0.3.dist-info/METADATA +76 -0
- dynestyx-0.0.3.dist-info/RECORD +34 -0
- dynestyx-0.0.3.dist-info/WHEEL +4 -0
- dynestyx-0.0.3.dist-info/licenses/LICENSE.md +202 -0
dynestyx/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Dynestyx package."""
|
|
2
|
+
|
|
3
|
+
from dynestyx.discretizers import Discretizer, euler_maruyama
|
|
4
|
+
from dynestyx.handlers import sample
|
|
5
|
+
from dynestyx.inference.filters import Filter
|
|
6
|
+
from dynestyx.models import (
|
|
7
|
+
ContinuousTimeStateEvolution,
|
|
8
|
+
DiracIdentityObservation,
|
|
9
|
+
DiscreteTimeStateEvolution,
|
|
10
|
+
DynamicalModel,
|
|
11
|
+
GaussianObservation,
|
|
12
|
+
GaussianStateEvolution,
|
|
13
|
+
LinearGaussianObservation,
|
|
14
|
+
LinearGaussianStateEvolution,
|
|
15
|
+
LTI_continuous,
|
|
16
|
+
LTI_discrete,
|
|
17
|
+
ObservationModel,
|
|
18
|
+
)
|
|
19
|
+
from dynestyx.simulators import (
|
|
20
|
+
DiscreteTimeSimulator,
|
|
21
|
+
ODESimulator,
|
|
22
|
+
SDESimulator,
|
|
23
|
+
Simulator,
|
|
24
|
+
)
|
|
25
|
+
from dynestyx.utils import flatten_draws
|
|
26
|
+
|
|
27
|
+
# Public API re-exported by ``from dynestyx import *``. Every entry must be a
# name bound in this module by the imports above: ``import *`` looks each entry
# up on the module and raises AttributeError for a missing one.
__all__ = [
    "ContinuousTimeStateEvolution",
    "DiscreteTimeStateEvolution",
    "DynamicalModel",
    "LTI_continuous",
    "LTI_discrete",
    "LinearGaussianStateEvolution",
    "GaussianStateEvolution",
    "Discretizer",
    "ObservationModel",
    "Filter",
    "flatten_draws",
    "sample",
    "DiracIdentityObservation",
    "LinearGaussianObservation",
    "GaussianObservation",
    "DiscreteTimeSimulator",
    "ODESimulator",
    "SDESimulator",
    "Simulator",
    "euler_maruyama",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Diagnostics utilities for dynestyx."""
|
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
# HMM
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def plot_hmm_states_and_observations(
    times,
    x,
    y,
    state_cmap="tab10",
    obs_cmap="Set1",
    show_fig=False,
    save_path=None,
    obs_style="auto",
    obs_marker="x",
):
    """
    Plot latent discrete HMM states as colored background bands
    with observed signals overlaid.

    :param times: (T,) Time points (irregular spacing is supported)
    :param x: (T,) Discrete latent state labels; they need not be 0..K-1,
        every distinct value gets its own band color
    :param y: (T,) or (T, N_obs) Observations
    :param state_cmap: Matplotlib colormap name used for the state bands
    :param obs_cmap: Matplotlib colormap name used for the observation traces
    :param show_fig: If True (and ``save_path`` is None), call ``plt.show()``
    :param save_path: Optional path; when given, the figure is saved there and closed
    :param obs_style: ``"auto"``, ``"line"`` or ``"scatter"``. ``"auto"`` draws a
        column as a scatter when it looks discrete-valued, otherwise as a line.
    :param obs_marker: Marker used for scatter-style observations
    :return: ``(fig, ax)``
    :raises ValueError: if ``obs_style`` is unknown, ``times`` is empty, or the
        shapes of ``x`` / ``y`` are inconsistent with ``times``
    """
    # Validate arguments before any figure exists so a bad call does not leave
    # an open, orphaned figure behind.
    if obs_style not in {"auto", "line", "scatter"}:
        raise ValueError("`obs_style` must be one of {'auto','line','scatter'}.")

    times = np.asarray(times)
    x = np.asarray(x)
    y = np.asarray(y)

    T = len(times)
    if T == 0:
        raise ValueError("`times` must contain at least one time point.")
    if x.ndim == 0 or x.shape[0] != T:
        raise ValueError(f"`x` must have shape (T,), got {x.shape} with T={T}.")
    if y.ndim not in (1, 2) or y.shape[0] != T:
        raise ValueError(
            f"`y` must have shape (T,) or (T, N_obs), got {y.shape} with T={T}."
        )

    # ---- Normalize observation shape ----
    if y.ndim == 1:
        y = y[:, None]  # (T, 1)
    N_obs = y.shape[1]

    # ---- Discrete state labels (may not be 0..K-1) ----
    state_values = np.unique(x)
    K = int(state_values.size)
    state_to_idx = {int(s): i for i, s in enumerate(state_values.tolist())}

    # ---- Time "edges" for clean contiguous state bands ----
    # Band boundaries sit at midpoints between consecutive samples (so irregular
    # sampling is handled); the outer edges extend by half the adjacent step.
    if T == 1:
        dt = 1.0
        edges = np.array([times[0] - 0.5 * dt, times[0] + 0.5 * dt])
    else:
        mids = 0.5 * (times[:-1] + times[1:])
        left = times[0] - 0.5 * (times[1] - times[0])
        right = times[-1] + 0.5 * (times[-1] - times[-2])
        edges = np.concatenate(([left], mids, [right]))

    # ---- Color maps ----
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap(name, lut)`
    # works across versions and resamples the colormap to `lut` entries.
    cmap_states = plt.get_cmap(state_cmap, K)
    state_colors = [cmap_states(k) for k in range(K)]

    cmap_obs = plt.get_cmap(obs_cmap, N_obs)
    obs_colors = [cmap_obs(i) for i in range(N_obs)]

    fig, ax = plt.subplots(figsize=(10, 4))

    # ---- Draw state background as contiguous segments ----
    # A run ends wherever two consecutive labels differ; each run is one band
    # spanning edges[start] .. edges[stop].
    change_points = np.flatnonzero(x[1:] != x[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_stops = np.concatenate((change_points, [T]))
    for start, stop in zip(run_starts, run_stops):
        k = state_to_idx[int(x[start])]
        ax.axvspan(
            edges[start],
            edges[stop],
            color=state_colors[k],
            alpha=0.18,
            linewidth=0,
        )

    # ---- Choose observation style ----
    # If observations are discrete-valued, lines look misleading; default to scatter.
    def _is_discrete_column(col: np.ndarray) -> bool:
        if np.issubdtype(col.dtype, np.integer) or np.issubdtype(col.dtype, np.bool_):
            return True
        # Heuristic: "few unique values" relative to length suggests discrete
        # categories, while continuous floats (e.g. SDE outputs) stay as lines.
        unique = np.unique(col)
        return unique.size <= min(20, max(3, T // 5))

    # ---- Plot observations ----
    for n in range(N_obs):
        col = y[:, n]
        style = obs_style
        if style == "auto":
            style = "scatter" if _is_discrete_column(col) else "line"

        if style == "line":
            ax.plot(
                times,
                col,
                color=obs_colors[n],
                lw=2,
                label=f"obs[{n}]",
                zorder=5,
            )
        else:
            ax.scatter(
                times,
                col,
                color=obs_colors[n],
                marker=obs_marker,
                s=35,
                linewidths=1.5,
                label=f"obs[{n}]",
                zorder=6,
            )

    # ---- Formatting ----
    ax.set_xlabel("Time")
    ax.set_ylabel("Observations")
    ax.set_title("HMM latent states and observations")
    ax.grid(True, alpha=0.3)

    # ---- Single legend: state patches followed by the labelled observations ----
    from matplotlib.patches import Patch

    state_patches = [
        Patch(
            facecolor=state_colors[state_to_idx[int(s)]],
            alpha=0.3,
            label=f"state {int(s)}",
        )
        for s in state_values
    ]
    ax.legend(
        handles=state_patches + ax.get_legend_handles_labels()[0],
        loc="upper left",
        frameon=False,
    )

    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    elif show_fig:
        plt.show()

    return fig, ax
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def plot_continuous_states_and_partial_observations(
    times, x, y, show_fig=False, save_path=None
):
    """
    Plot continuous latent states with partial noisy observations.

    One subplot is drawn per latent state dimension. The first ``obs_dim``
    state dimensions are assumed to be the observed ones: observation column
    ``i`` is overlaid on the subplot of state ``i``. Observation columns beyond
    ``state_dim`` are not drawn.

    :param times: (T,) Time points
    :param x: (T, state_dim) or (T,) Continuous latent states
    :param y: (T, obs_dim) or (T,) Observations
    :param show_fig: Whether to show the figure (ignored when ``save_path`` is given)
    :param save_path: Optional path to save the figure; the figure is then closed
    :return: ``(fig, axes)``; ``axes`` is a one-element list when ``state_dim == 1``
    :raises ValueError: if ``x`` / ``y`` are not 1-D or 2-D, or their lengths
        differ from ``times``
    """
    # np.asarray accepts JAX arrays directly; no round-trip through jnp needed.
    times = np.asarray(times)
    x = np.asarray(x)
    y = np.asarray(y)

    # Accept single-dimension series given as (T,).
    if x.ndim == 1:
        x = x[:, None]
    if y.ndim == 1:
        y = y[:, None]
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError(
            f"`x` and `y` must be 1-D or 2-D, got shapes {x.shape} and {y.shape}."
        )

    T, num_x = x.shape
    num_y = y.shape[1]
    if len(times) != T or y.shape[0] != T:
        raise ValueError(
            f"`times`, `x` and `y` must share the leading length, got "
            f"{len(times)}, {T} and {y.shape[0]}."
        )

    # Colors
    state_color = "C0"
    obs_color = "C2"

    # Figure. constrained_layout handles spacing; calling tight_layout() on top
    # of it makes Matplotlib (>= 3.6) warn that the layout engine changed.
    fig, axes = plt.subplots(
        num_x, 1, figsize=(10, 2.2 * num_x), sharex=True, constrained_layout=True
    )

    if num_x == 1:
        axes = [axes]

    # Plot
    for i, ax in enumerate(axes):
        # Latent state (legend label only on the first subplot)
        ax.plot(
            times,
            x[:, i],
            color=state_color,
            lw=2.0,
            alpha=0.95,
            label="Latent state" if i == 0 else None,
        )

        # Observations (assume first num_y states are observed)
        if i < num_y:
            ax.scatter(
                times,
                y[:, i],
                s=28,
                facecolors="none",
                edgecolors=obs_color,
                linewidth=1.0,
                alpha=0.7,
                zorder=3,
                label="Observation" if i == 0 else None,
            )

        ax.set_ylabel(f"x{i + 1}")
        ax.grid(True, alpha=0.3)

    axes[-1].set_xlabel("Time")

    # Legend
    axes[0].legend(loc="upper right", frameon=False, ncol=2)

    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    elif show_fig:
        plt.show()

    return fig, axes
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _drift_panel(fig, ax, grid, title, x1_range, x2_range, cmap, vlim=None):
    """Draw one (num_points, num_points) grid on `ax` as an image with a colorbar.

    The grid is indexed [x1, x2], so it is transposed for imshow (rows = x2).
    When `vlim` is given the color scale is symmetric, [-vlim, vlim].
    """
    limits = {} if vlim is None else {"vmin": -vlim, "vmax": vlim}
    im = ax.imshow(
        grid.T,
        origin="lower",
        extent=(*x1_range, *x2_range),
        cmap=cmap,
        aspect="auto",
        **limits,
    )
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    return im


def plot_drift_field(
    f_true,
    f_learned,
    f_learned_sd=None,
    x1_range=(-3.0, 3.0),
    x2_range=(-3.2, 3.2),
    num_points=50,
    return_rmse=False,
    relative_error=False,
    trajectory=None,
    trajectory_axes="error",
    trajectory_color="red",
    trajectory_lw=1.5,
    trajectory_alpha=0.85,
):
    """
    Plot true vs learned drift fields (2D state space).
    Optionally include learned uncertainty (stddev) and/or overlay a data trajectory.

    Layout: one row per drift output (f1, f2); columns are true, learned, error
    and, when `f_learned_sd` is given, stddev. True and learned panels of a row
    share a symmetric color scale so they are directly comparable.

    Args:
        f_true: callable f(x) -> (2,) array, true drift function
        f_learned: callable f(x) -> (2,) array, learned drift function
        f_learned_sd: optional callable f(x) -> (2,) array (stddev per output dim);
            a (1, 2) output per point is squeezed to (2,)
        x1_range: tuple (low, high) for x1 axis
        x2_range: tuple (low, high) for x2 axis
        num_points: number of grid points per axis
        return_rmse: if True, return (fig, rmse) where rmse is the root mean
            squared difference between learned and true drift over the grid
        relative_error: if True, plot |learned - true| / (|true| + 1e-6)
            instead of the absolute error
        trajectory: optional (T, 2) array of (x1, x2) points to overlay (e.g. data path)
        trajectory_axes: "error" (overlay on error panels only) or "all"
        trajectory_color: color for trajectory line
        trajectory_lw: linewidth for trajectory
        trajectory_alpha: alpha for trajectory

    Returns:
        fig, or (fig, rmse) if return_rmse is True

    Raises:
        ValueError: if `trajectory` is not (T, 2), or `trajectory_axes` is not
            "error"/"all" while a trajectory is given.
    """
    # Validate the trajectory before the (potentially expensive) drift evaluations.
    traj = None
    if trajectory is not None:
        traj = np.asarray(trajectory)
        if traj.ndim != 2 or traj.shape[1] != 2:
            raise ValueError("trajectory must have shape (T, 2) for (x1, x2)")
        if trajectory_axes not in ("error", "all"):
            raise ValueError('trajectory_axes must be "error" or "all"')

    x1 = jnp.linspace(x1_range[0], x1_range[1], num_points)
    x2 = jnp.linspace(x2_range[0], x2_range[1], num_points)
    X1, X2 = jnp.meshgrid(x1, x2, indexing="ij")
    grid_points = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)

    f_true_vals = jax.vmap(f_true)(grid_points)
    f_learned_vals = jax.vmap(f_learned)(grid_points)

    def _to_grid(vals, dim):
        # (num_points**2, 2) -> (num_points, num_points) grid of output `dim`,
        # indexed [x1, x2] because the meshgrid above uses indexing="ij".
        return np.asarray(vals[:, dim].reshape(num_points, num_points))

    sd_grids = None
    if f_learned_sd is not None:
        f_learned_sd_vals = jax.vmap(f_learned_sd)(grid_points)
        if f_learned_sd_vals.ndim == 3 and f_learned_sd_vals.shape[1] == 1:
            f_learned_sd_vals = f_learned_sd_vals.squeeze(1)
        sd_grids = [_to_grid(f_learned_sd_vals, d) for d in range(2)]

    true_grids = [_to_grid(f_true_vals, d) for d in range(2)]
    learned_grids = [_to_grid(f_learned_vals, d) for d in range(2)]

    err_grids = [np.abs(lg - tg) for lg, tg in zip(learned_grids, true_grids)]
    if relative_error:
        err_grids = [e / (np.abs(tg) + 1e-6) for e, tg in zip(err_grids, true_grids)]
    err_name = "relative error" if relative_error else "error"

    # Shared symmetric color limit per output. Fall back to 1.0 when both fields
    # are identically zero so vmin != vmax.
    vlims = [
        float(np.max(np.abs(np.concatenate([tg.ravel(), lg.ravel()])))) or 1.0
        for tg, lg in zip(true_grids, learned_grids)
    ]

    ncols = 4 if sd_grids is not None else 3
    fig, axes = plt.subplots(2, ncols, figsize=(5 * ncols, 8), constrained_layout=True)

    for d in range(2):
        label = f"f{d + 1}"
        row = axes[d]
        _drift_panel(
            fig, row[0], true_grids[d], f"{label} true",
            x1_range, x2_range, "seismic", vlims[d],
        )
        _drift_panel(
            fig, row[1], learned_grids[d], f"{label} learned",
            x1_range, x2_range, "seismic", vlims[d],
        )
        _drift_panel(
            fig, row[2], err_grids[d], f"{label} {err_name}",
            x1_range, x2_range, "viridis",
        )
        if sd_grids is not None:
            _drift_panel(
                fig, row[3], sd_grids[d], f"{label} stddev",
                x1_range, x2_range, "magma",
            )

    for ax in axes.ravel():
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")
        ax.grid(False)

    if traj is not None:
        if trajectory_axes == "error":
            overlay_axes = list(axes[:, 2])
        else:
            overlay_axes = list(axes.ravel())
        for ax in overlay_axes:
            ax.plot(
                traj[:, 0],
                traj[:, 1],
                color=trajectory_color,
                lw=trajectory_lw,
                alpha=trajectory_alpha,
                zorder=5,
            )

    if return_rmse:
        rmse = float(jnp.sqrt(jnp.mean((f_learned_vals - f_true_vals) ** 2)))
        return fig, rmse
    return fig
|
dynestyx/discretizers.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpyro.distributions as dist
|
|
3
|
+
from effectful.ops.semantics import fwd
|
|
4
|
+
from effectful.ops.syntax import ObjectInterpretation, implements
|
|
5
|
+
from jax import vmap
|
|
6
|
+
|
|
7
|
+
from dynestyx.handlers import HandlesSelf, _sample_intp
|
|
8
|
+
from dynestyx.models import (
|
|
9
|
+
ContinuousTimeStateEvolution,
|
|
10
|
+
DiscreteTimeStateEvolution,
|
|
11
|
+
DynamicalModel,
|
|
12
|
+
)
|
|
13
|
+
from dynestyx.types import FunctionOfTime
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _EulerMaruyamaDiscreteEvolution(DiscreteTimeStateEvolution):
    """Euler-Maruyama discretization of a ContinuousTimeStateEvolution.

    x_{t+1} ~ N(x + drift*dt, (L@Q@L.T)*dt), where Q is the identity of size
    ``cte.bm_dim`` and dt = t_next - t_now.
    """

    def __init__(self, cte: ContinuousTimeStateEvolution):
        self.cte = cte

    def __call__(self, x, u, t_now, t_next):
        """
        Discretize continuous-time state evolution via Euler-Maruyama. (CTSE) -> DTSE.

        We step from t_now to t_next for each timepoint provided (optionally just
        1 timepoint). Providing multiple timepoints is mainly useful with an
        observation model under which observations are temporally independent
        (presumably a Dirac observation such as `DiracIdentityObservation`), since
        then all timepoints can be stepped at once, giving big speedups.

        Args:
            x: (dim_state,) or (dim_state, num_timepoints)
            u: None, (dim_control,) or (dim_control, num_timepoints)
            t_now: scalar, (1,) or (num_timepoints,)
            t_next: scalar, (1,) or (num_timepoints,)

        Returns:
            dist.MultivariateNormal with
              - loc: (dim_state,) for 1-D x, else (num_timepoints, dim_state)
              - covariance_matrix: (dim_state, dim_state) for 1-D x, else
                (num_timepoints, dim_state, dim_state)
            The timepoint axis comes first because jax.vmap stacks outputs
            along a new leading axis.

        Raises:
            ValueError: if ``self.cte.bm_dim`` is not set.
        """
        # Checked once, outside the traced step function.
        if self.cte.bm_dim is None:
            raise ValueError(
                "ContinuousTimeStateEvolution.bm_dim is not set. "
                "Construct dynamics via DynamicalModel before discretization."
            )
        # Brownian-motion covariance (identity); the noise shape lives in L.
        Q = jnp.eye(self.cte.bm_dim)

        x = jnp.asarray(x)
        squeezed = x.ndim == 1
        if squeezed:
            x = x[:, None]  # (dim_state, 1) state
        if u is not None:
            u = jnp.asarray(u)
            if u.ndim == 1:
                u = u[:, None]  # (dim_control, 1) control
        # atleast_1d also accepts Python / NumPy scalars, not only 0-d arrays.
        t_now = jnp.atleast_1d(t_now)
        t_next = jnp.atleast_1d(t_next)

        def _step(_x, _u, _t_now, _t_next):
            _dt = _t_next - _t_now
            drift = self.cte.total_drift(_x, _u, _t_now)
            x_pred_mean = _x + drift * _dt
            L = self.cte.diffusion_coefficient(_x, _u, _t_now)
            x_pred_cov = L @ Q @ L.T * _dt
            return x_pred_mean, x_pred_cov

        # Map over the trailing timepoint axis of x / u and the time vectors;
        # a missing control is broadcast (not mapped).
        u_axis = None if u is None else 1
        loc, cov = vmap(_step, in_axes=(1, u_axis, 0, 0))(x, u, t_now, t_next)

        # If we lifted from unbatched, return unbatched dist shapes
        if squeezed:
            loc = loc[0]
            cov = cov[0]

        return dist.MultivariateNormal(loc=loc, covariance_matrix=cov)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def euler_maruyama(cte: ContinuousTimeStateEvolution) -> DiscreteTimeStateEvolution:
    """Discretize continuous-time state evolution via Euler-Maruyama.

    Euler-Maruyama is the simplest discrete approximation of a continuous-time
    state evolution. It is popular because it is simple and effective for simple
    models. Each resulting transition is Gaussian conditional on the current
    state. It is *linear*-Gaussian only when the drift is affine in the state and
    the diffusion coefficient does not depend on the state.

    Args:
        cte: `ContinuousTimeStateEvolution` to discretize.
    Returns:
        DiscreteTimeStateEvolution: The discretized state evolution.

    Note:
        No dt is passed; it is set to t_next - t_now in the __call__ method.

    ??? note "Algorithm Reference"
        The Euler-Maruyama scheme is a first-order discretization (weak order 1,
        strong order 1/2). The resulting discrete-time state evolution is
        approximated as

            x_{t+1} ~ N(x_t + drift * delta_t, (L@Q@L.T)*delta_t)

        where:
            x_t is the current state
            drift is the total drift evaluated at (x_t, u_t, t_now)
            L is the diffusion coefficient evaluated at (x_t, u_t, t_now)
            Q is the Brownian-motion covariance, taken as the identity of size bm_dim
            delta_t is the time step between timepoints (t_next - t_now)

        This is the first-order Ito-Taylor approximation.

        References:
            - This is the first-order Ito-Taylor approximation, discussed in Chapter 9.2 of: Särkkä, S., & Solin, A. (2019).
            Applied Stochastic Differential Equations. Cambridge University Press.
            [Available Online](https://users.aalto.fi/~asolin/sde-book/sde-book.pdf).
    """
    return _EulerMaruyamaDiscreteEvolution(cte)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Discretizer(ObjectInterpretation, HandlesSelf):
    """
    Turns a continuous-time state evolution into a discrete-time one at sample time.

    Wrap a model that calls `dsx.sample(...)` in a `Discretizer` context manager.
    When the sampled `DynamicalModel` carries a `ContinuousTimeStateEvolution`, the
    handler rebuilds the model with `discretize(state_evolution)` in its place.
    It then hands the result on with `fwd`. Models that are already discrete are
    forwarded untouched. Keep the `Discretizer` at a lower (i.e. inner) level of
    the context stack than any inference handler (e.g. `Filter` or `Simulator`).

    ??? example "Using a Euler Maruyama Discretizer"
        ```python
        import jax.numpy as jnp
        import numpyro.distributions as dist

        import dynestyx as dsx
        from dynestyx.discretizers import Discretizer, euler_maruyama
        from dynestyx.inference.filters import Filter, EKFConfig
        from dynestyx.models import (
            ContinuousTimeStateEvolution,
            DiscreteTimeStateEvolution,
            DynamicalModel,
        )

        state_dim = observation_dim = bm_dim = 2

        def model_with_ctse(obs_times=None, obs_values=None):
            dynamics = DynamicalModel(
                control_dim=0,
                initial_condition=dist.MultivariateNormal(
                    loc=jnp.zeros(state_dim),
                    covariance_matrix=jnp.eye(state_dim),
                ),
                state_evolution=ContinuousTimeStateEvolution(
                    drift=lambda x, u, t: x,
                    diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim),
                ),
                observation_model=lambda x, u, t: dist.MultivariateNormal(
                    x,
                    0.1**2 * jnp.eye(observation_dim),
                ),
            )
            return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

        def discretized_data_conditioned_model(obs_times, obs_values):
            # We use a discrete-time filter now
            with Filter(filter_config=EKFConfig()):
                with Discretizer(discretize=euler_maruyama):
                    return model_with_ctse(obs_times=obs_times, obs_values=obs_values)
        ```

    ??? note "Algorithm Reference"
        For an overview of discretization methods for SDEs, see Chapter 9 of: Särkkä, S., & Solin, A. (2019).
        Applied Stochastic Differential Equations. Cambridge University Press.
        [Available Online](https://users.aalto.fi/~asolin/sde-book/sde-book.pdf).

    Attributes:
        discretize: Callable mapping a continuous-time state evolution to a discrete-time state evolution. Defaults to euler_maruyama.
    """

    def __init__(self, discretize=euler_maruyama):
        super().__init__()
        self.discretize = discretize

    @implements(_sample_intp)
    def _sample_ds(
        self,
        name: str,
        dynamics: DynamicalModel,
        *,
        obs_times=None,
        obs_values=None,
        ctrl_times=None,
        ctrl_values=None,
        **kwargs,
    ) -> FunctionOfTime:
        """Swap in a discretized state evolution if needed, then forward the call."""
        evolution = dynamics.state_evolution
        if isinstance(evolution, ContinuousTimeStateEvolution):
            discretized = self.discretize(evolution)
            # Rebuild the model around the discrete evolution; every other
            # component is carried over unchanged.
            dynamics = DynamicalModel(
                initial_condition=dynamics.initial_condition,
                state_evolution=discretized,
                observation_model=dynamics.observation_model,
                control_model=dynamics.control_model,
                control_dim=dynamics.control_dim,
            )

        forwarded = dict(
            obs_times=obs_times,
            obs_values=obs_values,
            ctrl_times=ctrl_times,
            ctrl_values=ctrl_values,
            **kwargs,
        )
        return fwd(name, dynamics, **forwarded)
|