sindy-exp 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sindy_exp/_plotting.py ADDED
@@ -0,0 +1,544 @@
1
+ from dataclasses import dataclass
2
+ from typing import Annotated, Optional, Sequence
3
+ from warnings import warn
4
+
5
+ import matplotlib as mpl
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import scipy
9
+ import seaborn as sns
10
+ import sympy as sp
11
+ from matplotlib.axes import Axes
12
+ from matplotlib.figure import Figure
13
+ from matplotlib.typing import ColorType
14
+
15
# Module-level seaborn "Set1" qualitative palette.
PAL = sns.color_palette("Set1")
# Shared line styling unpacked into trajectory ``ax.plot`` calls.
PLOT_KWS = dict(alpha=0.7, linewidth=3)
17
+
18
+
19
@dataclass
class _ColorConstants:
    """Named plotting roles mapped onto fixed positions of a color sequence.

    Role -> index: TRUE=0, MEAS=1, EST=2, SIM=3, TRAIN=4, TEST=5.  Accessing
    a role whose index is beyond the end of ``color_sequence`` raises
    ``IndexError``.
    """

    color_sequence: list[ColorType]

    def set_sequence(self, color_sequence: list[ColorType]):
        """Swap in a new color sequence; role lookups use it immediately."""
        self.color_sequence = color_sequence

    # Each role is a read-only view of one slot in ``color_sequence``.  These
    # are plain class attributes (not annotated), so dataclass ignores them.
    TRUE = property(lambda self: self.color_sequence[0])
    MEAS = property(lambda self: self.color_sequence[1])
    EST = property(lambda self: self.color_sequence[2])
    SIM = property(lambda self: self.color_sequence[3])
    TRAIN = property(lambda self: self.color_sequence[4])
    TEST = property(lambda self: self.color_sequence[5])
49
+
50
+
51
# Default role colors, taken from matplotlib's "tab10" color sequence.
COLOR = _ColorConstants(color_sequence=mpl.color_sequences["tab10"])
52
+
53
+
54
def plot_coefficients(
    coefficients: Annotated[np.ndarray, "(n_coord, n_features)"],
    input_features: Optional[Sequence[str]],
    feature_names: Optional[Sequence[str]],
    ax: Axes,
    **heatmap_kws,
) -> Axes:
    """Plot a set of dynamical system coefficients in a heatmap.

    Exact zeros are replaced by NaN so inactive terms are left uncolored.
    The matrix is transposed before plotting: library functions run down the
    y-axis and system coordinates across the x-axis.

    Args:
        coefficients: A 2D array holding the coefficients of different
            library functions. System dimension is rows, function index
            is columns.
        input_features: system coordinate names, e.g. "x","y","z" or "u","v".
            One leading and one trailing "$" are stripped and the name is
            re-wrapped with a LaTeX dot accent.  If ``None``, labels
            ``x_0, x_1, ...`` (dotted) are generated.
        feature_names: the names of the functions in the library.  If
            ``None``, ``f0, f1, ...`` are generated.
        ax: the matplotlib axis to plot on
        **heatmap_kws: additional kwargs forwarded to ``seaborn.heatmap``;
            they override the defaults chosen here.

    Returns:
        The axis that was drawn on.
    """

    def _strip_math_delimiters(label: str) -> str:
        # Drop at most one "$" from each end so the name can be embedded in a
        # new math-mode string.  Safe for "" and "$" (indexing was not).
        return label.removeprefix("$").removesuffix("$")

    if input_features is None:
        input_features = [rf"$\dot x_{k}$" for k in range(coefficients.shape[0])]
    else:
        input_features = [
            rf"$\dot {_strip_math_delimiters(fi)}$" for fi in input_features
        ]

    if feature_names is None:
        feature_names = [f"f{k}" for k in range(coefficients.shape[1])]

    # Mask exact zeros so sparse models show only their active terms.
    masked = np.where(coefficients == 0, np.nan, coefficients)

    with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}):
        heatmap_args = {
            "xticklabels": input_features,
            "yticklabels": feature_names,
            "center": 0.0,
            "cmap": sns.color_palette("vlag", n_colors=20, as_cmap=True),
            "ax": ax,
            "linewidths": 0.1,
            "linecolor": "whitesmoke",
        }
        heatmap_args.update(heatmap_kws)
        sns.heatmap(masked.T, **heatmap_args)

    ax.tick_params(axis="y", rotation=0)

    return ax
107
+
108
+
109
def _coeff_dicts_to_matrix(
    coeffs: Sequence[dict[sp.Expr, float]],
) -> tuple[np.ndarray, list[str]]:
    """Convert a list of coefficient dictionaries to a dense matrix.

    Each dictionary maps feature identifiers (SymPy expressions) to
    coefficients for one coordinate of the same system.  The columns are the
    union of all keys, ordered by their string form so the ordering is
    consistent across coordinates.  A feature missing from a coordinate's
    dictionary is treated as a zero coefficient for that coordinate.

    Args:
        coeffs: one dictionary per coordinate; each becomes one row.

    Returns:
        ``(mat, feature_names)``: a float array of shape
        ``(len(coeffs), n_features)`` and the stringified features with
        ``**`` rendered as ``^``.

    Raises:
        ValueError: if ``coeffs`` is empty.
    """

    if not coeffs:
        raise ValueError("No coefficient dictionaries provided.")

    features: list[sp.Expr] = sorted({key for d in coeffs for key in d}, key=str)

    mat = np.zeros((len(coeffs), len(features)), dtype=float)
    for row, d in enumerate(coeffs):
        for col, feat in enumerate(features):
            # An absent key means the term is inactive for this coordinate;
            # indexing directly would raise KeyError on unaligned dicts.
            mat[row, col] = d.get(feat, 0.0)

    feature_names = [str(f).replace("**", "^") for f in features]
    return mat, feature_names
133
+
134
+
135
def _compare_coefficient_plots_impl(
    coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"],
    coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"],
    input_features: Sequence[str],
    feature_names: Sequence[str],
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
) -> None:
    """Draw true (left) and estimated (right) coefficient heatmaps.

    Both panels share a symmetric color range ``[-vmax, vmax]``.  When
    ``scaling`` is enabled and the nonzero magnitudes span more than a factor
    of 10, every value is mapped through ``sign(x) * |x| ** p`` with
    ``p = 1 / log10(max / min)`` to boost the color of small coefficients.

    Args:
        coefficients_est: estimated coefficients, coordinates as rows.
        coefficients_true: true coefficients, same shape.
        input_features: coordinate names (x-axis labels).
        feature_names: library feature names (y-axis labels).
        scaling: whether to apply the signed power compression above.
        axs: two axes to draw on.  If ``None``, a new 1x2 figure is created
            and laid out after plotting.
    """
    n_cols = len(coefficients_est)

    # helps boost the color of small coefficients. Maybe log is better?
    all_vals = np.hstack((coefficients_est.flatten(), coefficients_true.flatten()))
    nzs = np.abs(all_vals[all_vals.nonzero()])
    max_val = np.max(nzs, initial=0.0)
    min_val = np.min(nzs, initial=np.inf)
    if scaling and np.isfinite(min_val) and max_val / min_val > 10:
        pwr_ratio = 1.0 / np.log10(max_val / min_val)
    else:
        pwr_ratio = 1

    def signed_root(x):
        return np.sign(x) * np.power(np.abs(x), pwr_ratio)

    created_fig = None
    with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}):
        if axs is None:
            created_fig, axs = plt.subplots(
                1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True
            )

        vmax = signed_root(max_val)
        panels = (
            (axs[0], coefficients_true, "True Coefficients"),
            (axs[1], coefficients_est, "Est. Coefficients"),
        )
        for ax, coeffs, _ in panels:
            plot_coefficients(
                signed_root(coeffs),
                input_features=input_features,
                feature_names=feature_names,
                ax=ax,
                cbar=False,
                vmax=vmax,
                vmin=-vmax,
            )

    for ax, _, title in panels:
        ax.set_title(title, rotation=45)

    # Lay out only once tick labels and titles exist; calling tight_layout on
    # the empty figure (as before) could not account for them.
    if created_fig is not None:
        created_fig.tight_layout()
190
+
191
+
192
def compare_coefficient_plots(
    coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"],
    coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"],
    input_features: Sequence[str],
    feature_names: Sequence[str],
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
) -> None:
    """Plot true and estimated coefficient matrices side by side.

    Deprecated:
        Use :func:`compare_coefficient_plots_from_dicts` with coefficient
        dictionaries instead. This function will be removed in a future
        release.

    Emits a ``DeprecationWarning`` (attributed to the caller) on every call,
    then draws the same heatmaps as the dictionary-based API.
    """
    deprecation_msg = (
        "compare_coefficient_plots is deprecated; use "
        "compare_coefficient_plots_from_dicts with coefficient dictionaries instead."
    )
    warn(deprecation_msg, DeprecationWarning, stacklevel=2)

    _compare_coefficient_plots_impl(
        coefficients_est,
        coefficients_true,
        feature_names=feature_names,
        input_features=input_features,
        axs=axs,
        scaling=scaling,
    )
223
+
224
+
225
def compare_coefficient_plots_from_dicts(
    coefficients_est: Sequence[dict[sp.Expr, float]],
    coefficients_true: Sequence[dict[sp.Expr, float]],
    input_features: Sequence[str],
    feature_names: Sequence[str] | None = None,
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
) -> None:
    """Compare true and estimated coefficients given as dictionaries.

    Converts the coefficient dictionaries (one per coordinate) into dense
    matrices and draws them side by side using the shared heatmap
    implementation that also backs the deprecated
    :func:`compare_coefficient_plots` (without emitting its warning).

    This assumes that the coefficient dictionaries are aligned, i.e., that they
    contain the same keys across all coordinates, as produced by
    ``unionize_coeff_dicts()``.

    Args:
        coefficients_est: estimated coefficients, one dict per coordinate.
        coefficients_true: true coefficients, one dict per coordinate.
        input_features: coordinate names used as x-axis labels.
        feature_names: optional display names for the library features, one
            per matrix column (columns follow the string order of the keys).
            If ``None``, names are derived from the dictionary keys.
        scaling: whether to compress the color scale for small coefficients.
        axs: two axes to draw on, or ``None`` to create a new figure.

    Raises:
        ValueError: if the true and estimated dictionaries yield different
            shapes or feature sets, or ``feature_names`` has the wrong length.
    """

    true_mat, inferred_feature_names = _coeff_dicts_to_matrix(coefficients_true)
    est_mat, est_feature_names = _coeff_dicts_to_matrix(coefficients_est)

    if true_mat.shape != est_mat.shape:
        raise ValueError("True and estimated coefficient shapes do not match")

    if inferred_feature_names != est_feature_names:
        raise ValueError(
            "Feature names inferred from true and estimated coefficients do not match"
        )

    # Honor caller-supplied labels (previously this argument was ignored).
    if feature_names is None:
        feature_names = inferred_feature_names
    elif len(feature_names) != len(inferred_feature_names):
        raise ValueError(
            f"Expected {len(inferred_feature_names)} feature names, "
            f"got {len(feature_names)}"
        )

    _compare_coefficient_plots_impl(
        est_mat,
        true_mat,
        input_features=input_features,
        feature_names=feature_names,
        scaling=scaling,
        axs=axs,
    )
262
+
263
+
264
def _smoothing_differs(
    x_smooth: np.ndarray | None, x_train: np.ndarray, rtol: float = 1e-12
) -> bool:
    """Return True if ``x_smooth`` is given and differs from ``x_train``.

    The difference is measured by the 2-norm over all entries, relative to the
    2-norm of ``x_train``.  It is written as a product rather than a quotient
    so an all-zero ``x_train`` does not divide by zero.
    """
    if x_smooth is None:
        return False
    diff = np.linalg.norm(x_smooth - x_train)
    return bool(diff > rtol * np.linalg.norm(x_train))


def _plot_training_trajectory(
    ax: Axes,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: np.ndarray | None,
    labels: bool = True,
) -> None:
    """Plot a single training trajectory in phase space.

    Draws the true trajectory, the measured points, and, if provided, the
    smoothed points.  ``x_smooth`` is only drawn when it differs from
    ``x_train`` by more than a relative tolerance of 1e-12.  The same
    criterion is used for 2D and 3D data.

    Args:
        ax: axis to draw on; for 3-coordinate data it must be a 3D axis
            (the code passes three coordinate arrays and sets ``zlabel``).
        x_train: measured data, indexed as ``(sample, coordinate)`` with
            2 or 3 coordinates.
        x_true: true data, indexed the same way.
        x_smooth: optional smoothed data, same shape as ``x_train``.
        labels: if True, label the axes; otherwise hide the ticks.

    Raises:
        ValueError: if the data does not have 2 or 3 coordinates.
    """
    n_dims = x_train.shape[1]
    if n_dims not in (2, 3):
        raise ValueError("Can only plot 2d or 3d data.")
    show_smooth = _smoothing_differs(x_smooth, x_train)

    if n_dims == 2:
        ax.plot(
            x_true[:, 0], x_true[:, 1], ".", label="True", color=COLOR.TRUE, **PLOT_KWS
        )
        ax.plot(
            x_train[:, 0],
            x_train[:, 1],
            ".",
            label="Measured",
            color=COLOR.MEAS,
            **PLOT_KWS,
        )
        if show_smooth:
            ax.plot(
                x_smooth[:, 0],
                x_smooth[:, 1],
                ".",
                label="Smoothed",
                color=COLOR.EST,
                **PLOT_KWS,
            )
        if labels:
            ax.set(xlabel="$x_0$", ylabel="$x_1$")
        else:
            ax.set(xticks=[], yticks=[])
    else:
        ax.plot(
            x_true[:, 0],
            x_true[:, 1],
            x_true[:, 2],
            color=COLOR.TRUE,
            label="True values",
            **PLOT_KWS,
        )

        ax.plot(
            x_train[:, 0],
            x_train[:, 1],
            x_train[:, 2],
            ".",
            color=COLOR.MEAS,
            label="Measured values",
            alpha=0.3,
        )
        if show_smooth:
            ax.plot(
                x_smooth[:, 0],
                x_smooth[:, 1],
                x_smooth[:, 2],
                ".",
                color=COLOR.EST,
                label="Smoothed values",
                alpha=0.3,
            )
        if labels:
            ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$")
        else:
            ax.set(xticks=[], yticks=[], zticks=[])
342
+
343
+
344
def plot_training_data(
    t_train: np.ndarray,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: np.ndarray | None = None,
    coord_names: Optional[Sequence[str]] = None,
) -> tuple[Figure, Figure]:
    """Plot training data (and smoothed training data, if different).

    Returns two figures:

    * a composite figure: phase-space trajectory on the left (3D axis for
      three coordinates) and spectral magnitudes of measured and true data on
      the right;
    * a figure with one time-series panel per coordinate.

    Raises:
        ValueError: if the data has neither 2 nor 3 coordinates.
    """
    if coord_names is None:
        coord_names = [f"$x_{i}$" for i in range(x_true.shape[1])]

    fig_composite = plt.figure(figsize=(12, 6))
    n_dims = x_train.shape[-1]
    if n_dims not in (2, 3):
        raise ValueError("Too many or too few coordinates to plot")
    traj_projection = "3d" if n_dims == 3 else None
    ax_traj = fig_composite.add_subplot(1, 2, 1, projection=traj_projection)
    _plot_training_trajectory(ax_traj, x_train, x_true, x_smooth)
    ax_traj.legend()
    ax_traj.set(title="Trajectory Plot")

    ax_psd = fig_composite.add_subplot(1, 2, 2)
    for data, kind in ((x_train, "train"), (x_true, "true")):
        _plot_data_psd(ax_psd, data, coord_names, traj_type=kind)
    ax_psd.set(title="Absolute Spectral Density")

    n_coord = x_true.shape[-1]
    fig_by_coord_1d = plt.figure(figsize=(n_coord * 4, 6))
    for position, cname in enumerate(coord_names, start=1):
        panel = fig_by_coord_1d.add_subplot(n_coord, 1, position)
        _plot_training_1d(
            panel, position - 1, t_train, x_train, x_true, x_smooth, cname
        )

    fig_by_coord_1d.axes[-1].legend()

    return fig_composite, fig_by_coord_1d
379
+
380
+
381
def _plot_data_psd(
    ax: Axes,
    x: np.ndarray,
    coord_names: Sequence[str],
    traj_type: str = "train",
):
    """Draw the spectral magnitude of each column of ``x`` on log-log axes.

    For every coordinate, ``|rfft(series)| / sqrt(len(series))`` is plotted
    against the frequency index.  ``traj_type`` selects the line color and is
    appended to each legend label.

    Raises:
        ValueError: if ``traj_type`` is not "train", "sim", "true", or "smooth".
    """
    for type_name, role in (
        ("train", "MEAS"),
        ("sim", "SIM"),
        ("true", "TRUE"),
        ("smooth", "EST"),
    ):
        if traj_type == type_name:
            color = getattr(COLOR, role)
            break
    else:
        raise ValueError(f"Unknown traj_type '{traj_type}'")

    legend_labels = [name + f" {traj_type}" for name in coord_names]
    for label, column in zip(legend_labels, x.T):
        magnitude = np.abs(scipy.fft.rfft(column)) / np.sqrt(len(column))
        ax.loglog(magnitude, color=color, label=label)
    ax.legend()
    ax.set(xlabel="Wavenumber")
    ax.set(ylabel="Magnitude")
408
+
409
+
410
def _plot_training_1d(
    ax: Axes,
    coord_ind: int,
    t_train: np.ndarray,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: Optional[np.ndarray],
    coord_name: str,
):
    """Plot one coordinate of the training data against time.

    Measured samples are dots, the true signal is a solid line, and the
    smoothed signal (when given) uses the default line style.
    """

    def column(arr):
        return arr[..., coord_ind]

    ax.plot(t_train, column(x_train), ".", color=COLOR.MEAS, label="measured")
    ax.plot(t_train, column(x_true), "-", color=COLOR.TRUE, label="true")
    if x_smooth is not None:
        ax.plot(t_train, column(x_smooth), color=COLOR.EST, label="smoothed")
    ax.set(xlabel="t", ylabel=coord_name)
424
+
425
+
426
def _plot_pde_training_data(last_train, last_train_true, smoothed_last_train):
    """Show true, noisy, and smoothed PDE training data as three images.

    Only inputs whose ``last_train`` has exactly three dimensions are drawn
    (the case this module labels "1D"); otherwise nothing is plotted and
    ``None`` is returned.  All three images share the color range
    ``[0, last_train_true.max()]``.  Returns the result of ``plt.show()``.
    """
    if len(last_train.shape) != 3:
        return None

    _, axs = plt.subplots(1, 3, figsize=(18, 6))
    upper = last_train_true.max()
    images = (last_train_true, last_train, smoothed_last_train)
    titles = ("True Data", "Noisy Data", "Smoothed Data")
    for ax, image, title in zip(axs, images, titles):
        ax.imshow(image, vmin=0, vmax=upper)
        ax.set(title=title)
    return plt.show()
438
+
439
+
440
def _plot_test_sim_data_1d_panel(
    axs: Sequence[Axes],
    x_true: Optional[np.ndarray],
    x_sim: np.ndarray,
    t_test: np.ndarray,
    t_sim: np.ndarray,
    coord_names: Sequence[str],
) -> None:
    """Plot each coordinate of a simulation, and optionally the truth, vs time.

    Axis ``i`` gets column ``i`` of ``x_sim`` (dashed, against ``t_sim``) and,
    when ``x_true`` is given, column ``i`` of ``x_true`` (solid, against
    ``t_test``).  It is labelled with ``coord_names[i]``.
    """
    for idx, ax in enumerate(axs):
        if x_true is not None:
            ax.plot(t_test, x_true[:, idx], color=COLOR.TRUE, label="True")
        ax.plot(t_sim, x_sim[:, idx], "--", color=COLOR.SIM, label="Simulation")
        ax.set(xlabel="t", ylabel=coord_names[idx])
455
+
456
+
457
def _plot_test_sim_data_2d(
    ax: Axes,
    x_true: Optional[np.ndarray],
    x_sim: np.ndarray,
    labels: bool,
    coord_names: Sequence[str],
) -> None:
    """Phase-plane plot of a 2D simulation, optionally over the true trajectory.

    With ``labels`` the axes are named from ``coord_names``; otherwise the
    ticks are hidden.
    """
    if x_true is not None:
        ax.plot(x_true[:, 0], x_true[:, 1], color=COLOR.TRUE, label="True Values")
    ax.plot(x_sim[:, 0], x_sim[:, 1], "--", color=COLOR.SIM, label="Simulation")
    axis_settings = (
        dict(xlabel=coord_names[0], ylabel=coord_names[1])
        if labels
        else dict(xticks=[], yticks=[])
    )
    ax.set(**axis_settings)
471
+
472
+
473
def _plot_test_sim_data_3d(
    ax: Axes, x_vals: np.ndarray, label: Optional[str], coord_names: Sequence[str]
):
    """Draw a 3D line for ``x_vals`` on ``ax``.

    The color is taken from ``label``: "True" and "Simulation" use their role
    colors, anything else uses matplotlib's default.  A truthy ``label`` also
    names the three axes from ``coord_names``; a falsy one hides the ticks.
    """
    role = {"True": "TRUE", "Simulation": "SIM"}.get(label)
    color = getattr(COLOR, role) if role else None
    ax.plot(x_vals[:, 0], x_vals[:, 1], x_vals[:, 2], color=color, label=label)
    if not label:
        ax.set(xticks=[], yticks=[], zticks=[])
        return
    ax.set(xlabel=coord_names[0], ylabel=coord_names[1], zlabel=coord_names[2])
487
+
488
+
489
def plot_test_trajectory(
    x_true: np.ndarray,
    x_sim: np.ndarray,
    t_test: np.ndarray,
    t_sim: np.ndarray,
    figs: Optional[tuple[Figure, Figure]] = None,
    coord_names: Optional[Sequence[str]] = None,
) -> tuple[Figure, Figure]:
    """Plot a simulated test trajectory, on new or existing figures.

    Args:
        x_true: true test trajectory indexed as ``(sample, coordinate)``.
            Its column count sets the number of coordinates.  It is drawn
            only when new figures are created (``figs`` is falsy).
        x_sim: simulated trajectory with the same coordinates.
        t_test: times of ``x_true``.
        t_sim: times of ``x_sim``.
        figs: optional ``(fig_composite, fig_by_coord_1d)`` pair laid out like
            the figures from :func:`plot_training_data`.  The composite has a
            trajectory axis followed by a spectrum axis; the other figure has
            one axis per coordinate.  The simulation is overlaid on them.
        coord_names: per-coordinate labels; defaults to ``$x_i$``.

    Returns:
        ``(fig_composite, fig_by_coord_1d)``: the trajectory/spectrum figure
        and the per-coordinate time-series figure.

    Raises:
        ValueError: if new figures are needed and the data has neither 2 nor 3
            coordinates.
    """
    n_coord = x_true.shape[1]
    if coord_names is None:
        coord_names = [f"$x_{i}$" for i in range(n_coord)]

    if not figs:
        if n_coord not in (2, 3):
            raise ValueError("Can only plot 2d or 3d data.")
        fig_by_coord_1d, by_coord = plt.subplots(
            n_coord, 1, sharex=True, figsize=(7, 9), squeeze=False
        )
        # plt.subplots returns an ndarray of axes (the old isinstance(list)
        # assertion always failed here); flatten to a list like Figure.axes.
        axs_by_coord = list(by_coord.ravel())
        # Same layout as plot_training_data: only the trajectory axis is 3D,
        # so the log-log spectrum is drawn on an ordinary 2D axis.
        fig_composite = plt.figure(figsize=(10, 4.5))
        traj_projection = "3d" if n_coord == 3 else None
        axs_composite = [
            fig_composite.add_subplot(1, 2, 1, projection=traj_projection),
            fig_composite.add_subplot(1, 2, 2),
        ]
        truth = x_true
    else:
        fig_composite, fig_by_coord_1d = figs
        axs_composite = fig_composite.axes
        axs_by_coord = fig_by_coord_1d.axes
        # NOTE(review): supplied figures presumably already show the true data
        # (e.g. from plot_training_data), so only the simulation is overlaid.
        truth = None

    _plot_test_sim_data_1d_panel(
        axs_by_coord, truth, x_sim, t_test, t_sim, coord_names
    )
    axs_by_coord[-1].legend()

    if n_coord == 2:
        _plot_test_sim_data_2d(
            axs_composite[0], truth, x_sim, labels=True, coord_names=coord_names
        )
    elif n_coord == 3:
        if truth is not None:
            _plot_test_sim_data_3d(axs_composite[0], truth, "True", coord_names)
        _plot_test_sim_data_3d(axs_composite[0], x_sim, "Simulation", coord_names)
    axs_composite[0].legend()

    _plot_data_psd(axs_composite[1], x_sim, coord_names, traj_type="sim")
    if truth is not None:
        _plot_data_psd(axs_composite[1], truth, coord_names, traj_type="true")

    if not figs:
        fig_by_coord_1d.suptitle("Test Trajectories by Dimension")
        fig_composite.suptitle("Full Test Trajectories")
    # A single title; the old code set "true trajectory" and immediately
    # overwrote it on the same axis.
    traj_title = (
        "true trajectory and model simulation" if truth is not None
        else "model simulation"
    )
    axs_composite[0].set(title=traj_title)
    return fig_composite, fig_by_coord_1d
sindy_exp/_typing.py ADDED
@@ -0,0 +1,158 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Literal,
7
+ NamedTuple,
8
+ Optional,
9
+ Protocol,
10
+ TypeVar,
11
+ overload,
12
+ )
13
+
14
+ import numpy as np
15
+ import sympy as sp
16
+ from numpy.typing import NBitBase
17
+ from sympy import Expr
18
+ from typing_extensions import Self
19
+
20
# Generic shape parameter for arrays of arbitrary rank.
Shape = TypeVar("Shape", bound=tuple[int, ...])

# Floating-point dtype and fixed-rank float array aliases.
NpFlt = np.dtype[np.floating]
Float1D = np.ndarray[tuple[int], NpFlt]
Float2D = np.ndarray[tuple[int, int], NpFlt]
# Float array whose shape is the generic ``Shape`` parameter.
FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]


# Constrained to either a list of arrays or a single array.
TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
28
+
29
+
30
class _BaseSINDy(Protocol):
    """Structural type for the model interface used by this package.

    Any object providing these attributes and methods matches; subclassing is
    not required.
    """

    # Attributes expected on a model object.
    feature_names: list[str]
    feature_library: Any
    optimizer: Any

    def get_feature_names(self) -> list[str]:
        """Return the library feature names."""

    def fit(self, x: TrajectoryType, t: TrajectoryType, *args, **kwargs) -> Self:
        """Fit to trajectory data; returns a model of the same type."""

    def predict(self, x: np.ndarray, u: None | np.ndarray = None) -> np.ndarray:
        """Return model predictions for states ``x`` (optional input ``u``)."""

    def simulate(self, x0: np.ndarray, t: np.ndarray, **kwargs) -> np.ndarray:
        """Simulate the model from ``x0`` at times ``t``."""

    def score(
        self,
        x: TrajectoryType,
        t: TrajectoryType,
        x_dot: TrajectoryType,
        metric: Callable,
    ) -> float:
        """Score the model on data with ``metric``."""

    def coefficients(self):
        """Return the model coefficients."""

    def print(self, precision: int, **kwargs) -> None:
        """Print the model at the given ``precision``."""

    # ``equations`` yields strings by default and per-coordinate
    # {expression: coefficient} dicts when ``fmt="sympy"``.
    @overload
    def equations(self) -> list[str]: ...

    @overload
    def equations(self, precision: int) -> list[str]: ...

    @overload
    def equations(self, precision: int, fmt: Literal["str"] | None) -> list[str]: ...

    @overload
    def equations(
        self, precision: int, fmt: Literal["sympy"]
    ) -> list[dict[Expr, float]]: ...
68
+
69
+
70
class ProbData(NamedTuple):
    """Training data and metadata for one trajectory.

    Fields, in positional order: time step, sample times, measured states,
    true states, derivatives of the true states, coordinate names, and an
    optional integrator result (``None`` unless supplied).
    """

    dt: float  # sampling interval
    t_train: Float1D  # sample times
    x_train: Float2D  # measured states
    x_train_true: Float2D  # true states
    x_train_true_dot: Float2D  # derivatives of the true states
    input_features: list[str]  # coordinate names
    integrator: Optional[Any] = None  # diffrax.Solution
83
+
84
+
85
class NestedDict(defaultdict):
    """A dictionary that splits string keys on "." into nested sub-dicts.

    Assigning ``d["a.b"] = 1`` stores ``1`` under key ``"b"`` of a
    ``NestedDict`` at ``d["a"]`` (created on demand).  Reading ``d["a.b"]``
    resolves the path the same way.  Non-string keys are stored unchanged.

    Args: see superclass

    Example:

    >>> foo = NestedDict()
    >>> foo["a.b"] = 1
    >>> foo["a.c"] = 2
    >>> foo["a"]["b"]
    1
    >>> foo.flatten()
    {'a.b': 1, 'a.c': 2}
    """

    def __missing__(self, key):
        # Reached only when ``key`` is not stored directly: resolve a dotted
        # path one level at a time.  Anything else is a genuine miss, so raise
        # KeyError (non-str keys previously raised AttributeError).
        if not isinstance(key, str) or "." not in key:
            raise KeyError(key)
        prefix, subkey = key.split(".", 1)
        return self[prefix][subkey]

    def __setitem__(self, key, value):
        if not (isinstance(key, str) and "." in key):
            return super().__setitem__(key, value)
        prefix, suffix = key.split(".", 1)
        child = self.get(prefix)
        if child is None:
            child = NestedDict()
            super().__setitem__(prefix, child)
        # Recurses when ``child`` is a NestedDict; raises TypeError if the
        # prefix already holds a non-container value.
        child[suffix] = value

    def update(self, other=(), /, **kwargs):  # type: ignore
        """Insert items, splitting dotted string keys like ``__setitem__``.

        Mirrors ``dict.update``: ``other`` may be a mapping (anything with
        ``keys()``) or an iterable of ``(key, value)`` pairs, followed by
        keyword arguments.  Errors propagate instead of being swallowed; the
        old catch-all re-inserted the whole mapping without splitting, which
        left duplicate flat dotted keys beside the nested ones.
        """
        if hasattr(other, "keys"):
            pairs = ((k, other[k]) for k in other.keys())
        else:
            pairs = other
        for key, value in pairs:
            self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def flatten(self):
        """Return a new flat dict with dotted-path keys; does not mutate.

        Nested dicts are expanded recursively (empty sub-dicts contribute no
        entries).

        Raises:
            TypeError: if any key at any level is not a string.
        """

        def _flatten(nested_d: dict) -> dict:
            new = {}
            for key, value in nested_d.items():
                if not isinstance(key, str):
                    raise TypeError("Only string keys allowed in flattening")
                if not isinstance(value, dict):
                    new[key] = value
                    continue
                for sub_key, sub_value in _flatten(value).items():
                    new[key + "." + sub_key] = sub_value
            return new

        return _flatten(self)
137
+
138
+
139
@dataclass
class DynamicsTrialData:
    """Inputs and fitted results of one dynamics-identification trial.

    Equations are stored as lists of ``{expression: coefficient}``
    dictionaries; ``model`` is any object satisfying ``_BaseSINDy``.
    """

    trajectories: list[ProbData]  # per-trajectory training data
    true_equations: list[dict[sp.Expr, float]]  # reference coefficients
    sindy_equations: list[dict[sp.Expr, float]]  # fitted coefficients
    model: _BaseSINDy  # the fitted model object
    input_features: list[str]  # coordinate names
    smooth_train: list[np.ndarray]  # smoothed training arrays
147
+
148
+
149
@dataclass
class SINDyTrialUpdate:
    """Simulation output attached to a trial: times and simulated states."""

    t_sim: Float1D  # simulation times
    t_test: Float1D  # test-data times
    x_sim: FloatND  # simulated states
154
+
155
+
156
@dataclass
class FullDynamicsTrialData(DynamicsTrialData):
    """``DynamicsTrialData`` extended with simulation results."""

    sims: list[SINDyTrialUpdate]  # appended after all inherited fields