sindy-exp 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sindy_exp/__init__.py +28 -0
- sindy_exp/_data.py +202 -0
- sindy_exp/_diffrax_solver.py +104 -0
- sindy_exp/_dysts_to_sympy.py +452 -0
- sindy_exp/_odes.py +287 -0
- sindy_exp/_plotting.py +544 -0
- sindy_exp/_typing.py +158 -0
- sindy_exp/_utils.py +381 -0
- sindy_exp/addl_attractors.json +91 -0
- sindy_exp-0.2.0.dist-info/METADATA +111 -0
- sindy_exp-0.2.0.dist-info/RECORD +14 -0
- sindy_exp-0.2.0.dist-info/WHEEL +5 -0
- sindy_exp-0.2.0.dist-info/licenses/LICENSE +21 -0
- sindy_exp-0.2.0.dist-info/top_level.txt +1 -0
sindy_exp/_plotting.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Annotated, Optional, Sequence
|
|
3
|
+
from warnings import warn
|
|
4
|
+
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import scipy
|
|
9
|
+
import seaborn as sns
|
|
10
|
+
import sympy as sp
|
|
11
|
+
from matplotlib.axes import Axes
|
|
12
|
+
from matplotlib.figure import Figure
|
|
13
|
+
from matplotlib.typing import ColorType
|
|
14
|
+
|
|
15
|
+
# Module-wide plotting defaults: categorical palette and shared line styling.
PAL = sns.color_palette("Set1")
PLOT_KWS = {"alpha": 0.7, "linewidth": 3}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class _ColorConstants:
    """Semantic color roles backed by a swappable color sequence.

    Each role property reads a fixed slot of ``color_sequence``, so
    replacing the palette via :meth:`set_sequence` recolors every role
    at once.
    """

    color_sequence: list[ColorType]

    def set_sequence(self, color_sequence: list[ColorType]):
        """Replace the backing palette in place."""
        self.color_sequence = color_sequence

    def _slot(self, index):
        # Every role is a fixed offset into the currently active palette.
        return self.color_sequence[index]

    @property
    def TRUE(self):
        """Color for ground-truth data (slot 0)."""
        return self._slot(0)

    @property
    def MEAS(self):
        """Color for measured/noisy data (slot 1)."""
        return self._slot(1)

    @property
    def EST(self):
        """Color for estimated/smoothed data (slot 2)."""
        return self._slot(2)

    @property
    def SIM(self):
        """Color for simulated data (slot 3)."""
        return self._slot(3)

    @property
    def TRAIN(self):
        """Color for training data (slot 4)."""
        return self._slot(4)

    @property
    def TEST(self):
        """Color for test data (slot 5)."""
        return self._slot(5)


COLOR = _ColorConstants(mpl.color_sequences["tab10"])
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def plot_coefficients(
    coefficients: Annotated[np.ndarray, "(n_coord, n_features)"],
    input_features: Sequence[str],
    feature_names: Sequence[str],
    ax: Axes,
    **heatmap_kws,
) -> Axes:
    """Plot a set of dynamical system coefficients in a heatmap.

    Args:
        coefficients: A 2D array holding the coefficients of different
            library functions. System dimension is rows, function index
            is columns
        input_features: system coordinate names, e.g. "x","y","z" or "u","v".
            If None, generic names are generated.
        feature_names: the names of the functions in the library.  If None,
            generic names are generated.
        ax: the matplotlib axis to plot on
        **heatmap_kws: additional kwargs to seaborn's styling

    Returns:
        The axis that was drawn on.  (Previously annotated ``-> None``
        although the axis was always returned.)
    """

    def detex(label: str) -> str:
        # Strip surrounding "$" so the name can be re-embedded in math text.
        if label[0] == "$":
            label = label[1:]
        if label[-1] == "$":
            label = label[:-1]
        return label

    if input_features is None:
        input_features = [r"$\dot x_" + f"{k}$" for k in range(coefficients.shape[0])]
    else:
        input_features = [r"$\dot " + f"{detex(fi)}$" for fi in input_features]

    if feature_names is None:
        feature_names = [f"f{k}" for k in range(coefficients.shape[1])]

    with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}):
        heatmap_args = {
            "xticklabels": input_features,
            "yticklabels": feature_names,
            "center": 0.0,
            "cmap": sns.color_palette("vlag", n_colors=20, as_cmap=True),
            "ax": ax,
            "linewidths": 0.1,
            "linecolor": "whitesmoke",
        }
        heatmap_args.update(**heatmap_kws)
        # Mask exact zeros so pruned terms render as blank cells; a scalar
        # NaN broadcasts, no need to allocate a full NaN array.
        coefficients = np.where(coefficients == 0, np.nan, coefficients)
        sns.heatmap(coefficients.T, **heatmap_args)

    ax.tick_params(axis="y", rotation=0)

    return ax
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _coeff_dicts_to_matrix(
|
|
110
|
+
coeffs: Sequence[dict[sp.Expr, float]],
|
|
111
|
+
) -> tuple[np.ndarray, list[str]]:
|
|
112
|
+
"""Convert a list of coefficient dictionaries to a dense matrix.
|
|
113
|
+
|
|
114
|
+
The input is a list of dictionaries mapping feature identifiers (
|
|
115
|
+
SymPy expressions) to coefficients. All dictionaries are assumed to
|
|
116
|
+
correspond to different coordinates of the same system. This helper builds
|
|
117
|
+
a consistent feature ordering across coordinates and returns a numeric
|
|
118
|
+
matrix along with the stringified feature names.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
if not coeffs:
|
|
122
|
+
raise ValueError("No coefficient dictionaries provided.")
|
|
123
|
+
|
|
124
|
+
features: list[sp.Expr] = sorted({key for d in coeffs for key in d.keys()}, key=str)
|
|
125
|
+
|
|
126
|
+
mat = np.zeros((len(coeffs), len(features)), dtype=float)
|
|
127
|
+
for row, d in enumerate(coeffs):
|
|
128
|
+
for col, feat in enumerate(features):
|
|
129
|
+
mat[row, col] = d[feat]
|
|
130
|
+
|
|
131
|
+
feature_names = [str(f).replace("**", "^") for f in features]
|
|
132
|
+
return mat, feature_names
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _compare_coefficient_plots_impl(
    coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"],
    coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"],
    input_features: Sequence[str],
    feature_names: Sequence[str],
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
) -> None:
    """Internal implementation for coefficient comparison heatmaps.

    Draws the true coefficients on ``axs[0]`` and the estimated ones on
    ``axs[1]``, sharing a symmetric color scale centered at zero.

    Args:
        coefficients_est: estimated coefficient matrix.
        coefficients_true: true coefficient matrix, same shape.
        input_features: system coordinate names.
        feature_names: names of the library functions.
        scaling: if True and the nonzero magnitudes span more than one
            order of magnitude, compress values with a signed power law so
            small coefficients remain visible.
        axs: two axes to draw on; a new figure is created if omitted.
    """
    # NOTE(review): despite the name, this is the number of coordinate rows
    # (len of axis 0); it is only used to size the created figure — confirm.
    n_cols = len(coefficients_est)

    # helps boost the color of small coefficients. Maybe log is better?
    all_vals = np.hstack((coefficients_est.flatten(), coefficients_true.flatten()))
    nzs = all_vals[all_vals.nonzero()]
    # ``initial`` guards the all-zero case: max becomes 0.0 and min inf,
    # which disables scaling via the isfinite check below.
    max_val = np.max(np.abs(nzs), initial=0.0)
    min_val = np.min(np.abs(nzs), initial=np.inf)
    if scaling and np.isfinite(min_val) and max_val / min_val > 10:
        pwr_ratio = 1.0 / np.log10(max_val / min_val)
    else:
        pwr_ratio = 1

    def signed_root(x):
        # Sign-preserving |x| ** pwr_ratio; identity when pwr_ratio == 1.
        return np.sign(x) * np.power(np.abs(x), pwr_ratio)

    with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}):
        if axs is None:
            fig, axs = plt.subplots(
                1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True
            )
            fig.tight_layout()

        # Symmetric limits so the diverging colormap is centered at zero.
        vmax = signed_root(max_val)

        plot_coefficients(
            signed_root(coefficients_true),
            input_features=input_features,
            feature_names=feature_names,
            ax=axs[0],
            cbar=False,
            vmax=vmax,
            vmin=-vmax,
        )

        plot_coefficients(
            signed_root(coefficients_est),
            input_features=input_features,
            feature_names=feature_names,
            ax=axs[1],
            cbar=False,
            vmax=vmax,
            vmin=-vmax,
        )

        axs[0].set_title("True Coefficients", rotation=45)
        axs[1].set_title("Est. Coefficients", rotation=45)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def compare_coefficient_plots(
    coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"],
    coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"],
    input_features: Sequence[str],
    feature_names: Sequence[str],
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
) -> None:
    """Create plots of true and estimated coefficients.

    Deprecated:
        Use :func:`compare_coefficient_plots_from_dicts` with coefficient
        dictionaries instead. This function will be removed in a future
        release.
    """
    deprecation_msg = (
        "compare_coefficient_plots is deprecated; use "
        "compare_coefficient_plots_from_dicts with coefficient dictionaries instead."
    )
    # stacklevel=2 points the warning at the caller, not this shim.
    warn(deprecation_msg, DeprecationWarning, stacklevel=2)

    # All plotting work lives in the shared implementation.
    _compare_coefficient_plots_impl(
        coefficients_est,
        coefficients_true,
        input_features=input_features,
        feature_names=feature_names,
        scaling=scaling,
        axs=axs,
    )
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def compare_coefficient_plots_from_dicts(
    coefficients_est: Sequence[dict[sp.Expr, float]],
    coefficients_true: Sequence[dict[sp.Expr, float]],
    input_features: Sequence[str],
    feature_names: Sequence[str] | None = None,
    scaling: bool = True,
    axs: Optional[Sequence[Axes]] = None,
):
    """Wrapper to compare coefficients given as dictionaries.

    Converts coefficient dictionaries into dense matrices and then delegates
    to the shared heatmap implementation.

    The true and estimated dictionaries must cover the same set of features
    across all coordinates, e.g. as produced by ``unionize_coeff_dicts()``.

    Args:
        coefficients_est: per-coordinate estimated coefficient dictionaries.
        coefficients_true: per-coordinate true coefficient dictionaries.
        input_features: system coordinate names.
        feature_names: display names for the features.  If None, names are
            inferred from the dictionary keys.
        scaling: whether to boost the color of small coefficients.
        axs: two axes to draw on; a new figure is created if omitted.

    Raises:
        ValueError: if the two coefficient sets differ in shape or in the
            features they contain.
    """

    true_mat, inferred_feature_names = _coeff_dicts_to_matrix(coefficients_true)
    est_mat, est_feature_names = _coeff_dicts_to_matrix(coefficients_est)

    if true_mat.shape != est_mat.shape:
        raise ValueError("True and estimated coefficient shapes do not match")

    if inferred_feature_names != est_feature_names:
        raise ValueError(
            "Feature names inferred from true and estimated coefficients do not match"
        )

    # Bug fix: the ``feature_names`` argument was previously accepted but
    # silently ignored; honor it when provided.
    if feature_names is None:
        feature_names = inferred_feature_names

    _compare_coefficient_plots_impl(
        est_mat,
        true_mat,
        input_features=input_features,
        feature_names=feature_names,
        scaling=scaling,
        axs=axs,
    )
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _plot_training_trajectory(
    ax: Axes,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: np.ndarray | None,
    labels: bool = True,
) -> None:
    """Plot a single training trajectory

    If x_smooth is provided, it is only plotted if sufficiently different
    from x_train.

    Args:
        ax: axis to draw on (must be a 3d axis for 3-coordinate data).
        x_train: measured trajectory; coordinates along the second axis.
        x_true: true trajectory, same layout as ``x_train``.
        x_smooth: smoothed trajectory, or None to skip.
        labels: if True, label the axes; otherwise hide all ticks.

    Raises:
        ValueError: if the data is not 2- or 3-dimensional.
    """
    if x_train.shape[1] == 2:
        ax.plot(
            x_true[:, 0], x_true[:, 1], ".", label="True", color=COLOR.TRUE, **PLOT_KWS
        )
        ax.plot(
            x_train[:, 0],
            x_train[:, 1],
            ".",
            label="Measured",
            color=COLOR.MEAS,
            **PLOT_KWS,
        )
        # Only draw the smoothed curve when it differs measurably from the
        # raw measurements (relative norm of the difference).
        if (
            x_smooth is not None
            and np.linalg.norm(x_smooth - x_train) / np.linalg.norm(x_train) > 1e-12
        ):
            ax.plot(
                x_smooth[:, 0],
                x_smooth[:, 1],
                ".",
                label="Smoothed",
                color=COLOR.EST,
                **PLOT_KWS,
            )
        if labels:
            ax.set(xlabel="$x_0$", ylabel="$x_1$")
        else:
            ax.set(xticks=[], yticks=[])
    elif x_train.shape[1] == 3:
        ax.plot(
            x_true[:, 0],
            x_true[:, 1],
            x_true[:, 2],
            color=COLOR.TRUE,
            label="True values",
            **PLOT_KWS,
        )

        ax.plot(
            x_train[:, 0],
            x_train[:, 1],
            x_train[:, 2],
            ".",
            color=COLOR.MEAS,
            label="Measured values",
            alpha=0.3,
        )
        # NOTE(review): this branch normalizes the difference by
        # ``x_smooth.size`` while the 2d branch divides by the norm of
        # x_train — confirm the inconsistency is intentional.
        if (
            x_smooth is not None
            and np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12
        ):
            ax.plot(
                x_smooth[:, 0],
                x_smooth[:, 1],
                x_smooth[:, 2],
                ".",
                color=COLOR.EST,
                label="Smoothed values",
                alpha=0.3,
            )
        if labels:
            ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$")
        else:
            ax.set(xticks=[], yticks=[], zticks=[])
    else:
        raise ValueError("Can only plot 2d or 3d data.")
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def plot_training_data(
    t_train: np.ndarray,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: np.ndarray | None = None,
    coord_names: Optional[Sequence[str]] = None,
) -> tuple[Figure, Figure]:
    """Plot training data (and smoothed training data, if different).

    Returns a composite figure (trajectory + spectral density) and a
    figure of per-coordinate time series.
    """
    if coord_names is None:
        coord_names = [f"$x_{i}$" for i in range(x_true.shape[1])]

    composite = plt.figure(figsize=(12, 6))
    n_dim = x_train.shape[-1]
    # 3-coordinate data needs a 3d projection for the trajectory panel.
    if n_dim == 2:
        traj_ax = composite.add_subplot(1, 2, 1)
    elif n_dim == 3:
        traj_ax = composite.add_subplot(1, 2, 1, projection="3d")
    else:
        raise ValueError("Too many or too few coordinates to plot")
    _plot_training_trajectory(traj_ax, x_train, x_true, x_smooth)
    traj_ax.legend()
    traj_ax.set(title="Trajectory Plot")

    psd_ax = composite.add_subplot(1, 2, 2)
    _plot_data_psd(psd_ax, x_train, coord_names, traj_type="train")
    _plot_data_psd(psd_ax, x_true, coord_names, traj_type="true")
    psd_ax.set(title="Absolute Spectral Density")

    n_coord = x_true.shape[-1]
    per_coord = plt.figure(figsize=(n_coord * 4, 6))
    for index, name in enumerate(coord_names):
        axis = per_coord.add_subplot(n_coord, 1, index + 1)
        _plot_training_1d(axis, index, t_train, x_train, x_true, x_smooth, name)

    # A single legend on the bottom panel covers all coordinates.
    per_coord.axes[-1].legend()

    return composite, per_coord
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def _plot_data_psd(
    ax: Axes,
    x: np.ndarray,
    coord_names: Sequence[str],
    traj_type: str = "train",
):
    """Plot the power spectral density of training data.

    Each column of ``x`` becomes one log-log line, colored by the
    trajectory role named in ``traj_type``.
    """
    role_colors = {
        "train": COLOR.MEAS,
        "sim": COLOR.SIM,
        "true": COLOR.TRUE,
        "smooth": COLOR.EST,
    }
    if traj_type not in role_colors:
        raise ValueError(f"Unknown traj_type '{traj_type}'")
    color = role_colors[traj_type]
    legend_labels = [name + f" {traj_type}" for name in coord_names]
    for legend_label, series in zip(legend_labels, x.T):
        # Normalized magnitude of the real FFT of this coordinate.
        spectrum = np.abs(scipy.fft.rfft(series)) / np.sqrt(len(series))
        ax.loglog(spectrum, color=color, label=legend_label)
    ax.legend()
    ax.set(xlabel="Wavenumber")
    ax.set(ylabel="Magnitude")
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _plot_training_1d(
    ax: Axes,
    coord_ind: int,
    t_train: np.ndarray,
    x_train: np.ndarray,
    x_true: np.ndarray,
    x_smooth: Optional[np.ndarray],
    coord_name: str,
):
    """Plot one coordinate's measured, true, and (optional) smoothed series."""
    measured = x_train[..., coord_ind]
    truth = x_true[..., coord_ind]
    ax.plot(t_train, measured, ".", color=COLOR.MEAS, label="measured")
    ax.plot(t_train, truth, "-", color=COLOR.TRUE, label="true")
    # The smoothed series is optional; skip it entirely when absent.
    if x_smooth is not None:
        ax.plot(t_train, x_smooth[..., coord_ind], color=COLOR.EST, label="smoothed")
    ax.set(xlabel="t", ylabel=coord_name)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _plot_pde_training_data(last_train, last_train_true, smoothed_last_train):
    """Plot training data (and smoothed training data, if different).

    Shows true, noisy, and smoothed snapshots side by side with a shared
    color range taken from the true data.

    NOTE(review): only 3-dimensional arrays are handled; any other shape
    silently returns None.  ``plt.show()`` also returns None, so the
    function always returns None — confirm no caller uses the result.
    """
    # 1D:
    if len(last_train.shape) == 3:
        fig, axs = plt.subplots(1, 3, figsize=(18, 6))
        # Shared vmin/vmax so all three panels are directly comparable.
        axs[0].imshow(last_train_true, vmin=0, vmax=last_train_true.max())
        axs[0].set(title="True Data")
        axs[1].imshow(last_train, vmin=0, vmax=last_train_true.max())
        axs[1].set(title="Noisy Data")
        axs[2].imshow(smoothed_last_train, vmin=0, vmax=last_train_true.max())
        axs[2].set(title="Smoothed Data")
        return plt.show()
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def _plot_test_sim_data_1d_panel(
    axs: Sequence[Axes],
    x_true: Optional[np.ndarray],
    x_sim: np.ndarray,
    t_test: np.ndarray,
    t_sim: np.ndarray,
    coord_names: Sequence[str],
) -> None:
    """Plot simulated (and optionally true) series, one coordinate per axis."""
    for dim, ax in enumerate(axs):
        # The true trajectory is optional; simulation is always drawn.
        if x_true is not None:
            ax.plot(t_test, x_true[:, dim], color=COLOR.TRUE, label="True")
        ax.plot(t_sim, x_sim[:, dim], "--", color=COLOR.SIM, label="Simulation")
        ax.set(xlabel="t", ylabel=coord_names[dim])
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def _plot_test_sim_data_2d(
    ax: Axes,
    x_true: Optional[np.ndarray],
    x_sim: np.ndarray,
    labels: bool,
    coord_names: Sequence[str],
) -> None:
    """Overlay a simulated 2D trajectory on the true one (when given)."""
    if x_true is not None:
        ax.plot(x_true[:, 0], x_true[:, 1], color=COLOR.TRUE, label="True Values")
    ax.plot(x_sim[:, 0], x_sim[:, 1], "--", color=COLOR.SIM, label="Simulation")
    # Either annotate the axes or strip the ticks for a clean thumbnail.
    if not labels:
        ax.set(xticks=[], yticks=[])
    else:
        ax.set(xlabel=coord_names[0], ylabel=coord_names[1])
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _plot_test_sim_data_3d(
    ax: Axes, x_vals: np.ndarray, label: Optional[str], coord_names: Sequence[str]
):
    """Plot one 3D trajectory, colored by whether it is truth or simulation."""
    # Unknown/None labels fall back to matplotlib's default color (None).
    color = {"True": COLOR.TRUE, "Simulation": COLOR.SIM}.get(label)
    ax.plot(x_vals[:, 0], x_vals[:, 1], x_vals[:, 2], color=color, label=label)
    if label:
        ax.set(xlabel=coord_names[0], ylabel=coord_names[1], zlabel=coord_names[2])
    else:
        ax.set(xticks=[], yticks=[], zticks=[])
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def plot_test_trajectory(
    x_true: np.ndarray,
    x_sim: np.ndarray,
    t_test: np.ndarray,
    t_sim: np.ndarray,
    figs: Optional[tuple[Figure, Figure]] = None,
    coord_names: Optional[Sequence[str]] = None,
) -> tuple[Figure, Figure]:
    """Plot a test trajectory

    Args:
        x_true: true test trajectory; coordinates along the second axis.
            NOTE(review): only its shape is used — the true trajectory is
            never drawn (``None`` is passed to the panel helpers); confirm
            this is intentional.
        x_sim: model-simulated trajectory.
        t_test: sample times of the test data.
        t_sim: sample times of the simulation.
        figs: optionally an existing (composite, by-coordinate) figure pair
            to draw into; fresh figures are created when omitted.
        coord_names: coordinate display names; generic names if omitted.

    Returns:
        The composite figure (trajectory + spectral density) and the figure
        of per-coordinate time series.

    Raises:
        ValueError: if the data is not 2- or 3-dimensional.
    """
    if coord_names is None:
        coord_names = [f"$x_{i}$" for i in range(x_true.shape[1])]
    if not figs:
        fig_by_coord_1d, axs_by_coord = plt.subplots(
            x_true.shape[1], 1, sharex=True, figsize=(7, 9)
        )
        if x_true.shape[1] == 2:
            fig_composite, axs_composite = plt.subplots(1, 2, figsize=(10, 4.5))
        elif x_true.shape[1] == 3:
            fig_composite, axs_composite = plt.subplots(
                1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"}
            )
        else:
            raise ValueError("Can only plot 2d or 3d data.")
    else:
        fig_composite, fig_by_coord_1d = figs
        # Reuse the axes already attached to the provided figures.
        axs_composite = fig_composite.axes
        axs_by_coord = fig_by_coord_1d.axes

    # NOTE(review): ``plt.subplots`` returns a numpy array of axes, not a
    # list, so these asserts look like they would fail on the fresh-figure
    # path — confirm against actual usage.
    assert isinstance(axs_composite, list)
    assert isinstance(axs_by_coord, list)
    _plot_test_sim_data_1d_panel(axs_by_coord, None, x_sim, t_test, t_sim, coord_names)
    axs_by_coord[-1].legend()
    if x_true.shape[1] == 2:
        _plot_test_sim_data_2d(
            axs_composite[0], None, x_sim, labels=True, coord_names=coord_names
        )
    elif x_true.shape[1] == 3:
        _plot_test_sim_data_3d(axs_composite[0], x_sim, "Simulation", coord_names)
    axs_composite[0].legend()
    _plot_data_psd(axs_composite[1], x_sim, coord_names, traj_type="sim")
    if not figs:
        fig_by_coord_1d.suptitle("Test Trajectories by Dimension")
        fig_composite.suptitle("Full Test Trajectories")
        # NOTE(review): both titles target axs_composite[0]; the second
        # immediately overwrites the first — likely a copy/paste bug, but
        # the intended target axis is unclear (axs_composite[1] holds the
        # spectral density plot), so left as-is.
        axs_composite[0].set(title="true trajectory")
        axs_composite[0].set(title="model simulation")
    return fig_composite, fig_by_coord_1d
|
sindy_exp/_typing.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import (
|
|
4
|
+
Any,
|
|
5
|
+
Callable,
|
|
6
|
+
Literal,
|
|
7
|
+
NamedTuple,
|
|
8
|
+
Optional,
|
|
9
|
+
Protocol,
|
|
10
|
+
TypeVar,
|
|
11
|
+
overload,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import sympy as sp
|
|
16
|
+
from numpy.typing import NBitBase
|
|
17
|
+
from sympy import Expr
|
|
18
|
+
from typing_extensions import Self
|
|
19
|
+
|
|
20
|
+
# Shorthand array types used throughout the package.
NpFlt = np.dtype[np.floating]  # any floating-point numpy dtype
Float1D = np.ndarray[tuple[int], NpFlt]  # 1-D float array
Float2D = np.ndarray[tuple[int, int], NpFlt]  # 2-D float array
Shape = TypeVar("Shape", bound=tuple[int, ...])  # generic shape parameter
FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]  # float array of shape ``Shape``


# Functions accept either one trajectory (ndarray) or several (list of them).
TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _BaseSINDy(Protocol):
    """Structural (duck-typed) interface for the SINDy model API used here.

    Any object providing these attributes and methods can be stored in
    ``DynamicsTrialData.model`` — no inheritance required.
    """

    optimizer: Any
    feature_library: Any
    feature_names: list[str]

    def fit(self, x: TrajectoryType, t: TrajectoryType, *args, **kwargs) -> Self: ...

    def simulate(self, x0: np.ndarray, t: np.ndarray, **kwargs) -> np.ndarray: ...

    def score(
        self,
        x: TrajectoryType,
        t: TrajectoryType,
        x_dot: TrajectoryType,
        metric: Callable,
    ) -> float: ...

    def predict(self, x: np.ndarray, u: None | np.ndarray = None) -> np.ndarray: ...

    def coefficients(self): ...

    # ``equations`` returns printable strings by default, or coefficient
    # dictionaries keyed by sympy expressions when fmt="sympy".
    @overload
    def equations(self) -> list[str]: ...

    @overload
    def equations(self, precision: int) -> list[str]: ...

    @overload
    def equations(self, precision: int, fmt: Literal["str"] | None) -> list[str]: ...

    @overload
    def equations(
        self, precision: int, fmt: Literal["sympy"]
    ) -> list[dict[Expr, float]]: ...

    def print(self, precision: int, **kwargs) -> None: ...

    def get_feature_names(self) -> list[str]: ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ProbData(NamedTuple):
    """Data bundle for a single trajectory.

    Represents a trajectory's training data and associated metadata.
    """

    dt: float  # sampling interval of t_train
    t_train: Float1D  # sample times
    x_train: Float2D  # measured trajectory (coordinates along axis 1)
    x_train_true: Float2D  # true (noise-free) trajectory
    x_train_true_dot: Float2D  # true time derivative of the trajectory
    input_features: list[str]  # coordinate names, e.g. ["x", "y", "z"]
    integrator: Optional[Any] = None  # diffrax.Solution
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class NestedDict(defaultdict):
    """A dictionary that splits all keys by ".", creating a sub-dict.

    Args: see superclass

    Example:

        >>> foo = NestedDict()
        >>> foo["a.b"] = 1
        >>> foo["a.c"] = 2
        >>> foo["a"]["b"]
        1
    """

    def __missing__(self, key):
        # A key without a "." is genuinely absent; otherwise recurse into
        # the sub-dict named by the first segment.
        try:
            prefix, subkey = key.split(".", 1)
        except ValueError:
            raise KeyError(key)
        return self[prefix][subkey]

    def __setitem__(self, key, value):
        if "." in key:
            prefix, suffix = key.split(".", 1)
            # Create the intermediate level on demand.
            if self.get(prefix) is None:
                self[prefix] = NestedDict()
            return self[prefix].__setitem__(suffix, value)
        else:
            return super().__setitem__(key, value)

    def update(self, other: dict):  # type: ignore
        # NOTE(review): the bare except falls back to plain dict semantics
        # when dotted assignment fails (e.g. non-string keys), but items
        # assigned before the failure are not rolled back — confirm this
        # partial-update behavior is acceptable.
        try:
            for k, v in other.items():
                self.__setitem__(k, v)
        except:  # noqa: E722
            super().update(other)

    def flatten(self):
        """Flattens a nested dictionary without mutating. Returns new dict"""

        def _flatten(nested_d: dict) -> dict:
            new = {}
            for key, value in nested_d.items():
                if not isinstance(key, str):
                    raise TypeError("Only string keys allowed in flattening")
                if not isinstance(value, dict):
                    new[key] = value
                    continue
                # Re-join nested keys with "." so flatten inverts __setitem__.
                for sub_key, sub_value in _flatten(value).items():
                    new[key + "." + sub_key] = sub_value
            return new

        return _flatten(self)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass
class DynamicsTrialData:
    """Data produced by fitting a SINDy model to a set of trajectories."""

    trajectories: list[ProbData]  # training data, one entry per trajectory
    true_equations: list[dict[sp.Expr, float]]  # true coefficients per coordinate
    sindy_equations: list[dict[sp.Expr, float]]  # fitted coefficients per coordinate
    model: _BaseSINDy  # the fitted model
    input_features: list[str]  # coordinate names
    smooth_train: list[np.ndarray]  # smoothed training trajectories
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@dataclass
class SINDyTrialUpdate:
    """Simulation results for a trial (presumably on test data — confirm)."""

    t_sim: Float1D  # times at which the model was simulated
    t_test: Float1D  # times of the test data
    x_sim: FloatND  # simulated trajectory
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@dataclass
class FullDynamicsTrialData(DynamicsTrialData):
    """DynamicsTrialData augmented with simulation records."""

    sims: list[SINDyTrialUpdate]  # simulation output(s) for this trial
|