dynestyx 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. dynestyx/__init__.py +49 -0
  2. dynestyx/diagnostics/__init__.py +1 -0
  3. dynestyx/diagnostics/plotting_utils.py +434 -0
  4. dynestyx/discretizers.py +206 -0
  5. dynestyx/handlers.py +155 -0
  6. dynestyx/inference/__init__.py +7 -0
  7. dynestyx/inference/filter_configs.py +675 -0
  8. dynestyx/inference/filters.py +352 -0
  9. dynestyx/inference/hmm_filters.py +241 -0
  10. dynestyx/inference/integrations/__init__.py +3 -0
  11. dynestyx/inference/integrations/blackjax/__init__.py +5 -0
  12. dynestyx/inference/integrations/blackjax/mcmc.py +319 -0
  13. dynestyx/inference/integrations/cd_dynamax/__init__.py +1 -0
  14. dynestyx/inference/integrations/cd_dynamax/continuous.py +266 -0
  15. dynestyx/inference/integrations/cd_dynamax/discrete.py +316 -0
  16. dynestyx/inference/integrations/cd_dynamax/utils.py +440 -0
  17. dynestyx/inference/integrations/cuthbert/__init__.py +1 -0
  18. dynestyx/inference/integrations/cuthbert/discrete.py +458 -0
  19. dynestyx/inference/integrations/utils.py +83 -0
  20. dynestyx/inference/mcmc.py +158 -0
  21. dynestyx/inference/mcmc_configs.py +113 -0
  22. dynestyx/models/__init__.py +39 -0
  23. dynestyx/models/checkers.py +163 -0
  24. dynestyx/models/core.py +376 -0
  25. dynestyx/models/lti_dynamics.py +189 -0
  26. dynestyx/models/observations.py +112 -0
  27. dynestyx/models/state_evolution.py +147 -0
  28. dynestyx/simulators.py +1044 -0
  29. dynestyx/types.py +14 -0
  30. dynestyx/utils.py +223 -0
  31. dynestyx-0.0.3.dist-info/METADATA +76 -0
  32. dynestyx-0.0.3.dist-info/RECORD +34 -0
  33. dynestyx-0.0.3.dist-info/WHEEL +4 -0
  34. dynestyx-0.0.3.dist-info/licenses/LICENSE.md +202 -0
dynestyx/__init__.py ADDED
@@ -0,0 +1,49 @@
1
+ """Dynestyx package."""
2
+
3
+ from dynestyx.discretizers import Discretizer, euler_maruyama
4
+ from dynestyx.handlers import sample
5
+ from dynestyx.inference.filters import Filter
6
+ from dynestyx.models import (
7
+ ContinuousTimeStateEvolution,
8
+ DiracIdentityObservation,
9
+ DiscreteTimeStateEvolution,
10
+ DynamicalModel,
11
+ GaussianObservation,
12
+ GaussianStateEvolution,
13
+ LinearGaussianObservation,
14
+ LinearGaussianStateEvolution,
15
+ LTI_continuous,
16
+ LTI_discrete,
17
+ ObservationModel,
18
+ )
19
+ from dynestyx.simulators import (
20
+ DiscreteTimeSimulator,
21
+ ODESimulator,
22
+ SDESimulator,
23
+ Simulator,
24
+ )
25
+ from dynestyx.utils import flatten_draws
26
+
27
# Public API re-exported by the package.
# Every entry must name an object bound in this module; a dangling entry makes
# `from dynestyx import *` fail with AttributeError. "AffineDrift" was listed
# here without ever being imported, so it has been dropped.
__all__ = [
    "ContinuousTimeStateEvolution",
    "DiscreteTimeStateEvolution",
    "DynamicalModel",
    "LTI_continuous",
    "LTI_discrete",
    "LinearGaussianStateEvolution",
    "GaussianStateEvolution",
    "Discretizer",
    "ObservationModel",
    "Filter",
    "flatten_draws",
    "sample",
    "DiracIdentityObservation",
    "LinearGaussianObservation",
    "GaussianObservation",
    "DiscreteTimeSimulator",
    "ODESimulator",
    "SDESimulator",
    "Simulator",
    "euler_maruyama",
]
@@ -0,0 +1 @@
1
+ """Diagnostics utilities for dynestyx."""
@@ -0,0 +1,434 @@
1
+ # HMM
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+
7
+
8
def plot_hmm_states_and_observations(
    times,
    x,
    y,
    state_cmap="tab10",
    obs_cmap="Set1",
    show_fig=False,
    save_path=None,
    obs_style="auto",
    obs_marker="x",
):
    """
    Plot latent discrete HMM states as colored background bands
    with observed signals overlaid.

    :param times: (T,) Time points
    :param x: (T,) Discrete latent state labels (need not be 0..K-1)
    :param y: (T,) or (T, N_obs) Observations
    :param state_cmap: Name of the colormap used for the state bands
    :param obs_cmap: Name of the colormap used for the observation series
    :param show_fig: Whether to show the figure (ignored when ``save_path`` is given)
    :param save_path: Optional path; when given the figure is saved there and closed
    :param obs_style: ``"line"``, ``"scatter"``, or ``"auto"`` (scatter for
        discrete-looking columns, line otherwise)
    :param obs_marker: Marker used for scatter-style observations
    :return: ``(fig, ax)``
    :raises ValueError: If ``x`` or ``y`` do not have ``T`` rows, or if
        ``obs_style`` is not one of the accepted values
    """
    from matplotlib.patches import Patch

    times = np.asarray(times)
    x = np.asarray(x)
    y = np.asarray(y)

    T = len(times)
    if x.shape[0] != T:
        raise ValueError(f"`x` must have shape (T,), got {x.shape} with T={T}.")
    if y.shape[0] != T:
        raise ValueError(
            f"`y` must have shape (T,) or (T, N_obs), got {y.shape} with T={T}."
        )
    # Validate before any figure exists so a bad argument cannot leave an
    # orphaned, never-closed figure behind.
    if obs_style not in {"auto", "line", "scatter"}:
        raise ValueError("`obs_style` must be one of {'auto','line','scatter'}.")

    # ---- Normalize observation shape ----
    if y.ndim == 1:
        y = y[:, None]  # (T, 1)

    N_obs = y.shape[1]

    # ---- Discrete state labels (may not be 0..K-1) ----
    state_values = np.unique(x)
    K = int(state_values.size)
    state_to_idx = {int(s): i for i, s in enumerate(state_values.tolist())}

    # ---- Time "edges" for clean contiguous state bands ----
    # For irregular sampling, use midpoints between times; extend at ends by half-step.
    if T == 1:
        dt = 1.0
        edges = np.array([times[0] - 0.5 * dt, times[0] + 0.5 * dt])
    else:
        mids = 0.5 * (times[:-1] + times[1:])
        left = times[0] - 0.5 * (times[1] - times[0])
        right = times[-1] + 0.5 * (times[-1] - times[-2])
        edges = np.concatenate(([left], mids, [right]))

    # ---- Color maps ----
    # `plt.get_cmap` is used instead of `plt.cm.get_cmap`: the latter was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap_states = plt.get_cmap(state_cmap, K)
    state_colors = [cmap_states(k) for k in range(K)]

    cmap_obs = plt.get_cmap(obs_cmap, N_obs)
    obs_colors = [cmap_obs(i) for i in range(N_obs)]

    fig, ax = plt.subplots(figsize=(10, 4))

    # ---- Draw state background as contiguous segments ----
    # A band is emitted each time the state changes (or the series ends),
    # spanning from the edge of the run's first sample to the edge at `t`.
    run_start = 0
    for t in range(1, T + 1):
        if t == T or x[t] != x[run_start]:
            ax.axvspan(
                edges[run_start],
                edges[t],
                color=state_colors[state_to_idx[int(x[run_start])]],
                alpha=0.18,
                linewidth=0,
            )
            run_start = t

    # ---- Choose observation style ----
    # If observations are discrete-valued, lines look misleading; default to scatter.
    def _is_discrete_column(col: np.ndarray) -> bool:
        if np.issubdtype(col.dtype, np.integer) or np.issubdtype(col.dtype, np.bool_):
            return True
        # Heuristic: "few unique values" relative to length suggests discrete categories.
        # (Keeps continuous floats like SDE outputs as lines.)
        unique = np.unique(col)
        return unique.size <= min(20, max(3, T // 5))

    # ---- Plot observations ----
    for n in range(N_obs):
        col = y[:, n]
        style = obs_style
        if style == "auto":
            style = "scatter" if _is_discrete_column(col) else "line"

        if style == "line":
            ax.plot(
                times,
                col,
                color=obs_colors[n],
                lw=2,
                label=f"obs[{n}]",
                zorder=5,
            )
        else:
            ax.scatter(
                times,
                col,
                color=obs_colors[n],
                marker=obs_marker,
                s=35,
                linewidths=1.5,
                label=f"obs[{n}]",
                zorder=6,
            )

    # ---- Formatting ----
    ax.set_xlabel("Time")
    ax.set_ylabel("Observations")
    ax.set_title("HMM latent states and observations")

    ax.grid(True, alpha=0.3)

    # ---- Single combined legend: state patches first, then observations ----
    state_patches = [
        Patch(
            facecolor=state_colors[state_to_idx[int(s)]],
            alpha=0.3,
            label=f"state {int(s)}",
        )
        for s in state_values
    ]

    ax.legend(
        handles=state_patches + ax.get_legend_handles_labels()[0],
        loc="upper left",
        frameon=False,
    )

    fig.tight_layout()

    if save_path is not None:
        # Operate on `fig` explicitly rather than on pyplot's current figure.
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    elif show_fig:
        plt.show()

    return fig, ax
165
+
166
+
167
def plot_continuous_states_and_partial_observations(
    times, x, y, show_fig=False, save_path=None
):
    """
    Plot continuous latent states with partial noisy observations.

    One subplot is drawn per state dimension. Observation column ``i`` is
    overlaid on the subplot of state ``i`` (i.e. the first ``obs_dim`` states
    are assumed to be the observed ones).

    :param times: (T,) Time points
    :param x: (T, state_dim) Continuous latent states; a (T,) array is treated
        as a single state dimension
    :param y: (T, obs_dim) Observations; a (T,) array is treated as a single
        observation dimension
    :param show_fig: Whether to show the figure (ignored when ``save_path`` is given)
    :param save_path: Optional path; when given the figure is saved there and closed
    :return: ``(fig, axes)`` where ``axes`` is a one-element list when
        ``state_dim == 1`` and the array returned by ``plt.subplots`` otherwise
    """
    times = np.asarray(times)
    x = np.asarray(jnp.asarray(x))
    y = np.asarray(jnp.asarray(y))

    # Accept 1-D series, mirroring how the HMM plotter accepts (T,) observations.
    if x.ndim == 1:
        x = x[:, None]
    if y.ndim == 1:
        y = y[:, None]

    _, num_x = x.shape
    num_y = y.shape[1]

    # Colors
    state_color = "C0"
    obs_color = "C2"

    # Figure
    fig, axes = plt.subplots(
        num_x, 1, figsize=(10, 2.2 * num_x), sharex=True, constrained_layout=True
    )

    # With a single row `plt.subplots` returns a bare Axes; wrap it so the
    # loop and the indexing below work uniformly.
    if num_x == 1:
        axes = [axes]

    # Plot
    for i, ax in enumerate(axes):
        # Latent state (only the first panel contributes a legend entry)
        is_first_state = i == 0
        ax.plot(
            times,
            x[:, i],
            color=state_color,
            lw=2.0,
            alpha=0.95,
            label="Latent state" if is_first_state else None,
        )

        # Observations (assume first num_y states are observed)
        if i < num_y:
            is_first_obs = i == 0
            ax.scatter(
                times,
                y[:, i],
                s=28,
                facecolors="none",
                edgecolors=obs_color,
                linewidth=1.0,
                alpha=0.7,
                zorder=3,
                label="Observation" if is_first_obs else None,
            )

        ax.set_ylabel(f"x{i + 1}")
        ax.grid(True, alpha=0.3)

    axes[-1].set_xlabel("Time")

    # Legend
    axes[0].legend(loc="upper right", frameon=False, ncol=2)

    # NOTE(review): the figure is created with constrained_layout=True and then
    # tight_layout() is applied on top; recent Matplotlib versions are believed
    # to warn about the layout engine change here. Confirm and keep only one.
    fig.tight_layout()

    if save_path is not None:
        # Operate on `fig` explicitly rather than on pyplot's current figure.
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    elif show_fig:
        plt.show()

    return fig, axes
243
+
244
+
245
def plot_drift_field(
    f_true,
    f_learned,
    f_learned_sd=None,
    x1_range=(-3.0, 3.0),
    x2_range=(-3.2, 3.2),
    num_points=50,
    return_rmse=False,
    relative_error=False,
    trajectory=None,
    trajectory_axes="error",
    trajectory_color="red",
    trajectory_lw=1.5,
    trajectory_alpha=0.85,
):
    """
    Compare a true and a learned 2D drift field on a regular grid.

    The figure has two rows (drift components f1 and f2) and three columns
    (true, learned, error); a fourth "stddev" column is added when
    ``f_learned_sd`` is supplied. A data trajectory can be drawn on top.

    Args:
        f_true: callable f(x) -> (2,) array, true drift function
        f_learned: callable f(x) -> (2,) array, learned drift function
        f_learned_sd: optional callable f(x) -> (2,) array (stddev per output dim)
        x1_range: tuple (low, high) for x1 axis
        x2_range: tuple (low, high) for x2 axis
        num_points: number of grid points per axis
        return_rmse: if True, return (fig, rmse)
        relative_error: if True, plot relative error instead of absolute
        trajectory: optional (T, 2) array of (x1, x2) points to overlay (e.g. data path)
        trajectory_axes: "error" (overlay on error panels only) or "all"
        trajectory_color: color for trajectory line
        trajectory_lw: linewidth for trajectory
        trajectory_alpha: alpha for trajectory

    Returns:
        fig, or (fig, rmse) if return_rmse is True

    Raises:
        ValueError: if ``trajectory`` is given but is not of shape (T, 2), or
            if ``trajectory_axes`` is neither "error" nor "all".
    """
    # Evaluation grid, flattened to (num_points**2, 2) for vmap.
    grid_x1 = jnp.linspace(x1_range[0], x1_range[1], num_points)
    grid_x2 = jnp.linspace(x2_range[0], x2_range[1], num_points)
    X1, X2 = jnp.meshgrid(grid_x1, grid_x2, indexing="ij")
    grid_points = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)

    f_true_vals = jax.vmap(f_true)(grid_points)
    f_learned_vals = jax.vmap(f_learned)(grid_points)

    def _component(values, k):
        # Column k of a (N, 2) field, back on the (num_points, num_points) grid.
        return np.asarray(values[:, k].reshape(num_points, num_points))

    def _symmetric_limit(a, b):
        # Largest magnitude across both arrays, for a zero-centred color scale.
        return float(np.max(np.abs(np.concatenate([a.ravel(), b.ravel()]))))

    has_sd = f_learned_sd is not None
    if has_sd:
        f_learned_sd_vals = jax.vmap(f_learned_sd)(grid_points)
        # Drop a singleton middle axis if the stddev callable returns (1, 2).
        if f_learned_sd_vals.ndim == 3 and f_learned_sd_vals.shape[1] == 1:
            f_learned_sd_vals = f_learned_sd_vals.squeeze(1)
        f1_sd = _component(f_learned_sd_vals, 0)
        f2_sd = _component(f_learned_sd_vals, 1)
    else:
        f1_sd = f2_sd = None

    f1_true = _component(f_true_vals, 0)
    f2_true = _component(f_true_vals, 1)
    f1_learned = _component(f_learned_vals, 0)
    f2_learned = _component(f_learned_vals, 1)

    f1_err = np.abs(f1_learned - f1_true)
    f2_err = np.abs(f2_learned - f2_true)
    if relative_error:
        f1_err /= np.abs(f1_true) + 1e-6
        f2_err /= np.abs(f2_true) + 1e-6

    vlim1 = _symmetric_limit(f1_true, f1_learned)
    vlim2 = _symmetric_limit(f2_true, f2_learned)

    ncols = 4 if has_sd else 3
    fig, axes = plt.subplots(2, ncols, figsize=(5 * ncols, 8), constrained_layout=True)

    extent = (*x1_range, *x2_range)

    def _panel(ax, field, title, cmap, vlim=None):
        # One heatmap with its own colorbar; `vlim` pins a symmetric range.
        limits = {} if vlim is None else {"vmin": -vlim, "vmax": vlim}
        image = ax.imshow(
            field.T,
            origin="lower",
            extent=extent,
            cmap=cmap,
            aspect="auto",
            **limits,
        )
        ax.set_title(title)
        fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    # (field, title, colormap, symmetric limit) per column, one list per row.
    panel_rows = [
        [
            (f1_true, "f1 true", "seismic", vlim1),
            (f1_learned, "f1 learned", "seismic", vlim1),
            (f1_err, "f1 error", "viridis", None),
            (f1_sd, "f1 stddev", "magma", None),
        ],
        [
            (f2_true, "f2 true", "seismic", vlim2),
            (f2_learned, "f2 learned", "seismic", vlim2),
            (f2_err, "f2 error", "viridis", None),
            (f2_sd, "f2 stddev", "magma", None),
        ],
    ]
    for row, specs in enumerate(panel_rows):
        for col, (field, title, cmap, vlim) in enumerate(specs):
            if field is None:
                # No stddev callable was supplied, so there is no fourth column.
                continue
            _panel(axes[row, col], field, title, cmap, vlim)

    for ax in axes.ravel():
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")
        ax.grid(False)

    if trajectory is not None:
        traj = np.asarray(trajectory)
        if traj.ndim != 2 or traj.shape[1] != 2:
            raise ValueError("trajectory must have shape (T, 2) for (x1, x2)")
        if trajectory_axes == "error":
            overlay_axes = [axes[0, 2], axes[1, 2]]
        elif trajectory_axes == "all":
            overlay_axes = list(axes.ravel())
        else:
            raise ValueError('trajectory_axes must be "error" or "all"')
        for ax in overlay_axes:
            ax.plot(
                traj[:, 0],
                traj[:, 1],
                color=trajectory_color,
                lw=trajectory_lw,
                alpha=trajectory_alpha,
                zorder=5,
            )

    if not return_rmse:
        return fig
    # RMSE pools both drift components over all grid points.
    rmse = float(jnp.sqrt(jnp.mean((f_learned_vals - f_true_vals) ** 2)))
    return fig, rmse
@@ -0,0 +1,206 @@
1
+ import jax.numpy as jnp
2
+ import numpyro.distributions as dist
3
+ from effectful.ops.semantics import fwd
4
+ from effectful.ops.syntax import ObjectInterpretation, implements
5
+ from jax import vmap
6
+
7
+ from dynestyx.handlers import HandlesSelf, _sample_intp
8
+ from dynestyx.models import (
9
+ ContinuousTimeStateEvolution,
10
+ DiscreteTimeStateEvolution,
11
+ DynamicalModel,
12
+ )
13
+ from dynestyx.types import FunctionOfTime
14
+
15
+
16
class _EulerMaruyamaDiscreteEvolution(DiscreteTimeStateEvolution):
    """x_{t+1} ~ N(x + drift*dt, (L@Q@L.T)*dt), with Q the identity of size bm_dim."""

    def __init__(self, cte: ContinuousTimeStateEvolution):
        # The wrapped continuous-time evolution; its drift, diffusion
        # coefficient and `bm_dim` are read on every call.
        self.cte = cte

    def __call__(self, x, u, t_now, t_next):
        """
        Discretize continuous-time state evolution via Euler-Maruyama. (CTSE) -> DTSE.

        We step from t_now to t_next for each timepoint provided (optionally just 1 timepoint provided).
        The main use case of providing multiple timepoints is when paired with a Dirac observation model
        (presumably `DiracIdentityObservation`) that allows temporal independence between observations,
        which allows us to step through all timepoints at once (creating big speedups).

        Args:
            x: (dim_state,) or (dim_state, num_timepoints)
            u: None, (dim_control,) or (dim_control, num_timepoints)
            t_now: scalar, (1,) or (num_timepoints,)
            t_next: scalar, (1,) or (num_timepoints,)

        Returns:
            dist: MultivariateNormal distribution. The time axis is mapped over
            with `vmap` and therefore comes first in the outputs:
                - loc: (num_timepoints, dim_state), or (dim_state,) if `x` was 1-D
                - covariance_matrix: (num_timepoints, dim_state, dim_state),
                  or (dim_state, dim_state) if `x` was 1-D

        Raises:
            ValueError: if `self.cte.bm_dim` is None.
        """
        # Fail fast, before the drift/diffusion callables are evaluated, so the
        # caller sees this message rather than a downstream shape error.
        if self.cte.bm_dim is None:
            raise ValueError(
                "ContinuousTimeStateEvolution.bm_dim is not set. "
                "Construct dynamics via DynamicalModel before discretization."
            )

        # Lift unbatched inputs to a batch of one timepoint; remember whether
        # `x` was unbatched so the result can be un-lifted at the end.
        squeezed = False
        if x.ndim == 1:
            squeezed = True
            x = x[:, None]  # (dim_state, 1) state
        if u is not None:
            if u.ndim == 1:
                u = u[:, None]  # (dim_control, 1) control
        if t_now.ndim == 0:
            t_now = t_now[None]  # (1,) timepoint
        if t_next.ndim == 0:
            t_next = t_next[None]  # (1,) timepoint

        # Unit covariance of the driving noise (one entry per Brownian dimension).
        Q = jnp.eye(self.cte.bm_dim)

        def _step(_x, _u, _t_now, _t_next):
            # Single Euler-Maruyama step for one timepoint.
            _dt = _t_next - _t_now
            drift = self.cte.total_drift(_x, _u, _t_now)
            x_pred_mean = _x + drift * _dt
            L = self.cte.diffusion_coefficient(_x, _u, _t_now)
            x_pred_cov = L @ Q @ L.T * _dt
            return x_pred_mean, x_pred_cov

        # States/controls are batched along axis 1, times along axis 0.
        if u is None:
            loc, cov = vmap(_step, in_axes=(1, None, 0, 0))(x, None, t_now, t_next)
        else:
            loc, cov = vmap(_step, in_axes=(1, 1, 0, 0))(x, u, t_now, t_next)

        # If we lifted from unbatched, return unbatched dist shapes
        if squeezed:
            loc = loc[0]
            cov = cov[0]

        return dist.MultivariateNormal(loc=loc, covariance_matrix=cov)
79
+
80
+
81
def euler_maruyama(cte: ContinuousTimeStateEvolution) -> DiscreteTimeStateEvolution:
    """Discretize a continuous-time state evolution with the Euler-Maruyama scheme.

    Euler-Maruyama is a first-order discrete approximation of a continuous-time
    state evolution. It is popular because it is simple and works well for
    simple models. The resulting transition is Gaussian given the current
    state; its mean is linear in the state only when the drift is.

    Args:
        cte: `ContinuousTimeStateEvolution` to discretize.
    Returns:
        DiscreteTimeStateEvolution: The discretized state evolution.

    Note:
        No dt is passed; it is set to t_next - t_now in the __call__ method.

    ??? note "Algorithm Reference"
        Euler-Maruyama is a first-order discretization.
        The resulting discrete-time state evolution is approximated as

            x_{t+1} ~ N(x_t + drift * delta_t, (L@Q@L.T)*delta_t)

        where:
            x_t is the current state
            drift is the drift function
            L is the diffusion coefficient
            Q is the diffusion covariance (the identity of size bm_dim in this implementation)
            delta_t is the time step between timepoints (t_next - t_now)

        References:
            - This is the first-order Ito-Taylor approximation, discussed in Chapter 9.2 of: Särkkä, S., & Solin, A. (2019).
                Applied Stochastic Differential Equations. Cambridge University Press.
                [Available Online](https://users.aalto.fi/~asolin/sde-book/sde-book.pdf).
    """
    discretized = _EulerMaruyamaDiscreteEvolution(cte)
    return discretized
117
+
118
+
119
class Discretizer(ObjectInterpretation, HandlesSelf):
    """
    Handler that swaps a continuous-time state evolution for a discrete-time one.

    Use a `Discretizer` as a context manager around a model call containing a
    `dsx.sample(...)` statement. When the sampled dynamics carry a
    `ContinuousTimeStateEvolution`, it is replaced by the result of
    `discretize(...)` before the call is forwarded; any other state evolution
    is forwarded untouched. The `Discretizer` should sit at a lower (i.e.
    inner) level of the context stack than any inference object (e.g. `Filter`
    or `Simulator`).

    ??? example "Using a Euler Maruyama Discretizer"
        ```python
        import dynestyx as dsx
        from dynestyx.discretizers import Discretizer, euler_maruyama
        from dynestyx.inference.filters import Filter, EKFConfig
        from dynestyx.models import (
            ContinuousTimeStateEvolution,
            DiscreteTimeStateEvolution,
            DynamicalModel,
        )

        def model_with_ctse(obs_times=None, obs_values=None):
            dynamics = DynamicalModel(
                control_dim=0,
                initial_condition=dist.MultivariateNormal(
                    loc=jnp.zeros(state_dim),
                    covariance_matrix=jnp.eye(state_dim),
                ),
                state_evolution=ContinuousTimeStateEvolution(
                    drift=lambda x, u, t: x,
                    diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim),
                ),
                observation_model=lambda x, u, t: dist.MultivariateNormal(
                    x,
                    0.1**2 * jnp.eye(observation_dim),
                ),
            )
            return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

        def discretized_data_conditioned_model():
            # We use a discrete-time filter now
            with Filter(filter_config=EKFConfig()):
                with Discretizer(discretize=euler_maruyama):
                    return model_with_ctse(obs_times=obs_times, obs_values=obs_values)
        ```

    ??? note "Algorithm Reference"
        For an overview of discretization methods for SDEs, see Chapter 9 of: Särkkä, S., & Solin, A. (2019).
        Applied Stochastic Differential Equations. Cambridge University Press.
        [Available Online](https://users.aalto.fi/~asolin/sde-book/sde-book.pdf).

    Attributes:
        discretize: A callable that converts a continuous-time state evolution to a discrete-time state evolution. Defaults to euler_maruyama.
    """

    def __init__(self, discretize=euler_maruyama):
        super().__init__()
        self.discretize = discretize

    @implements(_sample_intp)
    def _sample_ds(
        self,
        name: str,
        dynamics: DynamicalModel,
        *,
        obs_times=None,
        obs_values=None,
        ctrl_times=None,
        ctrl_values=None,
        **kwargs,
    ) -> FunctionOfTime:
        """Discretize continuous-time dynamics, then forward the sample call."""
        evolution = dynamics.state_evolution
        if isinstance(evolution, ContinuousTimeStateEvolution):
            # Rebuild the model around the discretized evolution, carrying
            # over every other component unchanged.
            stepped = self.discretize(evolution)
            dynamics = DynamicalModel(
                initial_condition=dynamics.initial_condition,
                state_evolution=stepped,
                observation_model=dynamics.observation_model,
                control_model=dynamics.control_model,
                control_dim=dynamics.control_dim,
            )

        # Hand the (possibly rewritten) call to the next handler in the stack.
        forwarded = {
            "obs_times": obs_times,
            "obs_values": obs_values,
            "ctrl_times": ctrl_times,
            "ctrl_values": ctrl_values,
        }
        forwarded.update(kwargs)
        return fwd(name, dynamics, **forwarded)