sindy-exp 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sindy_exp/_odes.py ADDED
@@ -0,0 +1,287 @@
1
+ from importlib import resources
2
+ from logging import getLogger
3
+ from typing import Callable, TypeVar
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pysindy as ps
8
+ import sympy as sp
9
+ from dysts.base import DynSys
10
+
11
+ from ._plotting import (
12
+ compare_coefficient_plots_from_dicts,
13
+ plot_test_trajectory,
14
+ plot_training_data,
15
+ )
16
+ from ._typing import (
17
+ DynamicsTrialData,
18
+ FullDynamicsTrialData,
19
+ ProbData,
20
+ SINDyTrialUpdate,
21
+ _BaseSINDy,
22
+ )
23
+ from ._utils import (
24
+ _simulate_test_data,
25
+ coeff_metrics,
26
+ integration_metrics,
27
+ unionize_coeff_dicts,
28
+ )
29
+
30
# Preferred optimization direction for each evaluation metric name:
# "max" for metrics where larger values are better, "min" where smaller are.
metric_ordering = dict(
    coeff_precision="max",
    coeff_f1="max",
    coeff_recall="max",
    coeff_mae="min",
    coeff_mse="min",
    mse_plot="min",
    mae_plot="min",
)
39
+
40
+
41
# Type variables for annotating array-valued right-hand-side callables:
# ``T`` parametrizes the (1-D) shape tuple and ``DType`` the array dtype.
T = TypeVar("T", bound=int)
DType = TypeVar("DType", bound=np.dtype)

MOD_LOG = getLogger(__name__)

# Metadata file shipped inside the ``sindy_exp`` package; passed as
# ``metadata_path`` by the dynamical-system classes in this module.
LOCAL_DYNAMICS_PATH = resources.files("sindy_exp") / "addl_attractors.json"
45
+
46
+
47
def _add_forcing(
    forcing_func: Callable[[float], np.ndarray[tuple[T], DType]],
    auto_func: Callable[
        [float, np.ndarray[tuple[T], DType]], np.ndarray[tuple[T], DType]
    ],
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Wrap a right-hand-side function with an additive time-dependent forcing.

    Args:
        forcing_func: Callable of time ``t`` returning the forcing term.
        auto_func: Existing right-hand side ``f(t, state)`` in the calling
            convention used by ``solve_ivp``.

    Returns:
        A callable ``(t, state) -> forcing_func(t) + auto_func(t, state)``.
        Both terms are converted with ``np.array`` before being summed, and
        the forcing is evaluated first.
    """

    def forced_rhs(
        t: float, state: np.ndarray[tuple[T], DType]
    ) -> np.ndarray[tuple[T], DType]:
        forcing_term = np.array(forcing_func(t))
        autonomous_term = np.array(auto_func(t, state))
        return forcing_term + autonomous_term

    return forced_rhs
69
+
70
+
71
class LotkaVolterra(DynSys):
    """Lotka-Volterra (predator-prey) dynamical system.

    dx/dt = alpha * x - beta * x * y
    dy/dt = delta * x * y - gamma * y

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    # Class-level flag read by the dysts base class (presumably marks the state
    # as nonnegative -- confirm against dysts).
    nonnegative = True

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
        """Evaluate the Lotka-Volterra right-hand side.

        Args:
            x: prey population
            y: predator population
            t: time (unused; the system is autonomous)
            alpha: prey growth rate
            beta: predation rate
            gamma: predator death rate
            delta: predator reproduction rate
        """
        # NOTE(review): the parameters are not in alphabetical order; if dysts
        # passes them positionally in sorted-name order, gamma and delta would
        # be swapped -- confirm against the dysts version in use.
        return np.array([alpha * x - beta * x * y, delta * x * y - gamma * y])
95
+
96
+
97
class Hopf(DynSys):
    """Hopf normal form dynamical system.

    dx/dt = mu * x - omega * y - A * (x^3 + x * y^2)
    dy/dt = omega * x + mu * y - A * (x^2 * y + y^3)

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
        """Evaluate the Hopf normal form; ``t`` is unused (autonomous)."""
        # NOTE(review): parameters are not in alphabetical order; if dysts
        # passes them positionally in sorted-name order, confirm this is right.
        cubic_x = A * (x**3 + x * y**2)
        cubic_y = A * (x**2 * y + y**3)
        return np.array(
            [
                mu * x - omega * y - cubic_x,
                omega * x + mu * y - cubic_y,
            ]
        )
108
+
109
+
110
class SHO(DynSys):
    """Linear damped simple harmonic oscillator.

    dx/dt = a * x + b * y
    dy/dt = c * x + d * y

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
        """Evaluate the linear right-hand side; ``t`` is unused (autonomous)."""
        return np.array([a * x + b * y, c * x + d * y])
121
+
122
+
123
class CubicHO(DynSys):
    """Cubic damped harmonic oscillator.

    dx/dt = a * x^3 + b * y^3
    dy/dt = c * x^3 + d * y^3

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
        """Evaluate the cubic right-hand side; ``t`` is unused (autonomous)."""
        return np.array([a * x**3 + b * y**3, c * x**3 + d * y**3])
134
+
135
+
136
class VanDerPol(DynSys):
    """Van der Pol oscillator.

    dx/dt = y
    dy/dt = mu * (1 - x^2) * y - x

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
        """Evaluate the first-order form; ``x_dot`` plays the role of ``y``.

        ``t`` is unused since the system is autonomous.
        """
        acceleration = mu * (1 - x**2) * x_dot - x
        return np.array([x_dot, acceleration])
151
+
152
+
153
class Kinematics(DynSys):
    """One-dimensional kinematics with constant acceleration.

    dx/dt = v
    dv/dt = a

    Constructed with ``metadata_path=LOCAL_DYNAMICS_PATH``.
    """

    def __init__(self):
        super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)

    @staticmethod
    def _rhs(x, v, t: float, a) -> np.ndarray:
        """Return ``[v, a]``; position ``x`` and time ``t`` do not enter."""
        return np.array([v, a])
168
+
169
+
170
def fit_eval(
    data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
    model: _BaseSINDy,
    simulations: bool = True,
    display: bool = True,
    return_all: bool = True,
) -> dict:
    """Fit and evaluate a SINDy model on a set of trajectories.

    Args:
        data: Tuple of (trajectories, true_equations), where ``trajectories`` is
            a non-empty list of ProbData objects (feature names are taken from
            the first one) and ``true_equations`` is a list of dictionaries
            mapping SymPy symbols to their true coefficients for each state
            coordinate.
        model: A SINDy-like model implementing the _BaseSINDy protocol. It is
            fit in place. If its ``feature_library`` is a
            ``ps.WeakPDELibrary``, that attribute is replaced (after fitting
            and coefficient extraction) by the weak library's
            ``function_library`` so the model can be simulated.
        simulations: Whether to run one forward simulation per trajectory and
            add the trajectory-averaged integration metrics to the result.
        display: Whether to print the model and generate plots.
        return_all: If True, return ``{"metrics": ..., "data": ...,
            "main": metrics["main"]}``; otherwise return only the metrics
            dictionary.

    Returns:
        A dictionary as described under ``return_all``. The ``"data"`` entry is
        a ``FullDynamicsTrialData`` (carrying ``sims``) when ``simulations`` is
        True, and a ``DynamicsTrialData`` otherwise.
    """
    trajectories, true_equations = data
    # All trajectories are assumed to share the first one's feature names;
    # an empty ``trajectories`` list raises IndexError here.
    input_features = trajectories[0].input_features

    x_train = [traj.x_train for traj in trajectories]
    t_train = [traj.t_train for traj in trajectories]
    # Log before fitting so the message precedes the work it describes.
    MOD_LOG.info("Fitting a model: %s", model)
    model.fit(x_train, t=t_train, feature_names=input_features)

    coeff_true_dicts, coeff_est_dicts = unionize_coeff_dicts(model, true_equations)
    if isinstance(model.feature_library, ps.WeakPDELibrary):
        # WeakPDE library fails to simulate, so insert nonweak library
        # to Pipeline and SINDy model.
        inner_lib = model.feature_library.function_library
        model.feature_library = inner_lib  # type: ignore # TODO: Fix in pysindy

    if isinstance(model, ps.SINDy) and hasattr(
        model.differentiation_method, "smoothed_x_"
    ):
        # Re-run the differentiation method on each trajectory so that its
        # ``smoothed_x_`` attribute holds that trajectory's smoothed states.
        smooth_x = []
        for traj in trajectories:
            model.differentiation_method(traj.x_train, t=traj.t_train)
            smooth_x.append(model.differentiation_method.smoothed_x_)
    else:
        # No smoothed states are exposed (e.g. weak-form fits, or a
        # differentiation method without ``smoothed_x_``): use raw data.
        smooth_x = x_train

    trial_data = DynamicsTrialData(
        trajectories=trajectories,
        true_equations=coeff_true_dicts,
        sindy_equations=coeff_est_dicts,
        model=model,
        input_features=input_features,
        smooth_train=smooth_x,
    )
    MOD_LOG.info("Evaluating a model: %s", model)
    metrics = coeff_metrics(coeff_est_dicts, coeff_true_dicts)

    if simulations:
        sims: list[SINDyTrialUpdate] = []
        integration_metric_list: list[dict[str, float | np.floating]] = []
        for traj in trajectories:
            sims.append(_simulate_test_data(model, traj.dt, traj.x_train_true))
            integration_metric_list.append(
                integration_metrics(
                    model,
                    traj.x_train_true,
                    traj.t_train,
                    traj.x_train_true_dot,
                )
            )

        # Average each integration metric over trajectories; the key set is
        # taken from the first trajectory's metrics.
        metrics.update(
            {
                key: float(np.mean([m[key] for m in integration_metric_list]))
                for key in integration_metric_list[0]
            }
        )

        trial_data = FullDynamicsTrialData(sims=sims, **trial_data.__dict__)

    if display:
        plot_ode_panel(trial_data)
        for i, traj in enumerate(trajectories):
            fig_composite, fig_by_coord_1d = plot_training_data(
                traj.t_train,
                traj.x_train,
                traj.x_train_true,
                x_smooth=smooth_x[i],
                coord_names=input_features,
            )
            if simulations:
                # Overlay test trajectory time series on the coordinate-wise figure
                plot_test_trajectory(
                    traj.x_train_true,
                    sims[i].x_sim,
                    traj.t_train,
                    sims[i].t_sim,
                    figs=(fig_composite, fig_by_coord_1d),
                    coord_names=input_features,
                )

    if return_all:
        return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
    return metrics
270
+
271
+
272
def plot_ode_panel(trial_data: DynamicsTrialData):
    """Print the fitted model and plot estimated vs. true coefficients.

    Calls ``trial_data.model.print()``, draws the coefficient comparison with
    ``compare_coefficient_plots_from_dicts`` using ``$``-wrapped feature names,
    then calls ``plt.show()``.
    """
    fitted = trial_data.model
    fitted.print()
    tex_features = list(map(_texify, trial_data.input_features))
    compare_coefficient_plots_from_dicts(
        trial_data.sindy_equations,
        trial_data.true_equations,
        input_features=tex_features,
    )
    plt.show()
280
+
281
+
282
+ def _texify(input: str) -> str:
283
+ if input[0] != "$":
284
+ input = "$" + input
285
+ if input[-1] != "$":
286
+ input = input + "$"
287
+ return input