sindy-exp 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sindy_exp/__init__.py +2 -1
- sindy_exp/_data.py +30 -25
- sindy_exp/_diffrax_solver.py +12 -4
- sindy_exp/_odes.py +11 -33
- sindy_exp/_plotting.py +11 -5
- sindy_exp/_typing.py +24 -9
- sindy_exp/_utils.py +3 -4
- {sindy_exp-0.2.2.dist-info → sindy_exp-0.3.0.dist-info}/METADATA +24 -17
- sindy_exp-0.3.0.dist-info/RECORD +15 -0
- sindy_exp-0.2.2.dist-info/RECORD +0 -15
- {sindy_exp-0.2.2.dist-info → sindy_exp-0.3.0.dist-info}/WHEEL +0 -0
- {sindy_exp-0.2.2.dist-info → sindy_exp-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {sindy_exp-0.2.2.dist-info → sindy_exp-0.3.0.dist-info}/top_level.txt +0 -0
sindy_exp/__init__.py
CHANGED
|
@@ -7,13 +7,14 @@ from ._plotting import (
|
|
|
7
7
|
plot_test_trajectory,
|
|
8
8
|
plot_training_data,
|
|
9
9
|
)
|
|
10
|
-
from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData
|
|
10
|
+
from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData, SimProbData
|
|
11
11
|
from ._utils import coeff_metrics, integration_metrics, pred_metrics
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
14
14
|
"gen_data",
|
|
15
15
|
"fit_eval",
|
|
16
16
|
"ProbData",
|
|
17
|
+
"SimProbData",
|
|
17
18
|
"DynamicsTrialData",
|
|
18
19
|
"FullDynamicsTrialData",
|
|
19
20
|
"coeff_metrics",
|
sindy_exp/_data.py
CHANGED
|
@@ -11,7 +11,7 @@ from dysts.base import DynSys
|
|
|
11
11
|
|
|
12
12
|
from ._dysts_to_sympy import dynsys_to_sympy
|
|
13
13
|
from ._plotting import plot_training_data
|
|
14
|
-
from ._typing import ExperimentResult, Float1D,
|
|
14
|
+
from ._typing import ExperimentResult, Float1D, SimProbData
|
|
15
15
|
from ._utils import _sympy_expr_to_feat_coeff
|
|
16
16
|
|
|
17
17
|
try:
|
|
@@ -42,7 +42,7 @@ def gen_data(
|
|
|
42
42
|
t_end: float = 10,
|
|
43
43
|
display: bool = False,
|
|
44
44
|
array_namespace: str = "numpy",
|
|
45
|
-
) -> ExperimentResult[tuple[list[
|
|
45
|
+
) -> ExperimentResult[tuple[list[SimProbData], list[dict[sp.Expr, float]]]]:
|
|
46
46
|
"""Generate random training and test data
|
|
47
47
|
|
|
48
48
|
An Experiment step according to the mitosis experiment runner.
|
|
@@ -75,7 +75,7 @@ def gen_data(
|
|
|
75
75
|
coeff_true = _sympy_expr_to_feat_coeff(sp_expr)
|
|
76
76
|
rhsfunc = lambda t, X: dyst_sys.rhs(X, t) # noqa: E731
|
|
77
77
|
try:
|
|
78
|
-
x0_center = dyst_sys.ic
|
|
78
|
+
x0_center = cast(Float1D, dyst_sys.ic)
|
|
79
79
|
except KeyError:
|
|
80
80
|
x0_center = np.zeros((len(input_features)), dtype=np.float64)
|
|
81
81
|
try:
|
|
@@ -88,7 +88,7 @@ def gen_data(
|
|
|
88
88
|
noise_abs = 0.1
|
|
89
89
|
|
|
90
90
|
MOD_LOG.info(f"Generating {n_trajectories} trajectories of f{system}")
|
|
91
|
-
prob_data_list: list[
|
|
91
|
+
prob_data_list: list[SimProbData] = []
|
|
92
92
|
if array_namespace == "numpy":
|
|
93
93
|
feature_names = [feat.name for feat in input_features]
|
|
94
94
|
for _ in range(n_trajectories):
|
|
@@ -108,20 +108,20 @@ def gen_data(
|
|
|
108
108
|
prob_data_list.append(prob)
|
|
109
109
|
elif array_namespace == "jax":
|
|
110
110
|
try:
|
|
111
|
-
|
|
112
|
-
except
|
|
111
|
+
jax # type: ignore
|
|
112
|
+
except NameError:
|
|
113
113
|
raise ImportError(
|
|
114
114
|
"jax data generation requested but diffrax or sympy2jax not"
|
|
115
115
|
" installed"
|
|
116
116
|
)
|
|
117
|
-
this_seed = jax.random.PRNGKey(seed)
|
|
117
|
+
this_seed = jax.random.PRNGKey(seed) # type: ignore
|
|
118
118
|
for _ in range(n_trajectories):
|
|
119
|
-
this_seed, _ = jax.random.split(this_seed)
|
|
120
|
-
prob = _gen_data_jax(
|
|
119
|
+
this_seed, _ = jax.random.split(this_seed) # type: ignore
|
|
120
|
+
prob = _gen_data_jax( # type: ignore
|
|
121
121
|
sp_expr,
|
|
122
122
|
input_features,
|
|
123
123
|
this_seed,
|
|
124
|
-
x0_center=x0_center,
|
|
124
|
+
x0_center=x0_center, # type: ignore # numpy->jax
|
|
125
125
|
nonnegative=nonnegative,
|
|
126
126
|
ic_stdev=ic_stdev,
|
|
127
127
|
noise_abs=noise_abs,
|
|
@@ -136,8 +136,10 @@ def gen_data(
|
|
|
136
136
|
)
|
|
137
137
|
if display and prob_data_list:
|
|
138
138
|
sample = prob_data_list[0]
|
|
139
|
+
assert sample.x_train_true is not None # typing
|
|
139
140
|
figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
|
|
140
141
|
figs[0].suptitle("Sample Trajectory")
|
|
142
|
+
|
|
141
143
|
return {
|
|
142
144
|
"data": (prob_data_list, coeff_true),
|
|
143
145
|
"main": f"{n_trajectories} trajectories of {rhsfunc}",
|
|
@@ -156,7 +158,7 @@ def _gen_data(
|
|
|
156
158
|
nonnegative: bool,
|
|
157
159
|
dt: float,
|
|
158
160
|
t_end: float,
|
|
159
|
-
) ->
|
|
161
|
+
) -> SimProbData:
|
|
160
162
|
rng = np.random.default_rng(seed)
|
|
161
163
|
t_train = np.arange(0, t_end, dt)
|
|
162
164
|
t_train_span = (t_train[0], t_train[-1])
|
|
@@ -180,8 +182,9 @@ def _gen_data(
|
|
|
180
182
|
noise_abs = np.sqrt(_signal_avg_power(x_train) * noise_rel)
|
|
181
183
|
x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape)
|
|
182
184
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
+
assert noise_abs is not None # typing
|
|
186
|
+
return SimProbData(
|
|
187
|
+
t_train, x_train, input_features, x_train_true, x_train_true_dot, noise_abs
|
|
185
188
|
)
|
|
186
189
|
|
|
187
190
|
|
|
@@ -208,10 +211,12 @@ class LotkaVolterra(DynSys):
|
|
|
208
211
|
nonnegative = True
|
|
209
212
|
|
|
210
213
|
def __init__(self):
|
|
211
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
214
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
212
215
|
|
|
213
216
|
@staticmethod
|
|
214
|
-
def _rhs(
|
|
217
|
+
def _rhs( # type: ignore # dysts
|
|
218
|
+
x, y, t: float, alpha, beta, gamma, delta
|
|
219
|
+
) -> np.ndarray:
|
|
215
220
|
"""LV dynamics
|
|
216
221
|
|
|
217
222
|
Args:
|
|
@@ -233,10 +238,10 @@ class Hopf(DynSys):
|
|
|
233
238
|
"""Hopf normal form dynamical system."""
|
|
234
239
|
|
|
235
240
|
def __init__(self):
|
|
236
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
241
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
237
242
|
|
|
238
243
|
@staticmethod
|
|
239
|
-
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
|
|
244
|
+
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray: # type: ignore # dysts
|
|
240
245
|
dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
|
|
241
246
|
dydt = omega * x + mu * y - A * (x**2 * y + y**3)
|
|
242
247
|
return np.array([dxdt, dydt])
|
|
@@ -247,10 +252,10 @@ class SHO(DynSys):
|
|
|
247
252
|
"""Linear damped simple harmonic oscillator"""
|
|
248
253
|
|
|
249
254
|
def __init__(self):
|
|
250
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
255
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
251
256
|
|
|
252
257
|
@staticmethod
|
|
253
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
258
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
|
|
254
259
|
dxdt = a * x + b * y
|
|
255
260
|
dydt = c * x + d * y
|
|
256
261
|
return np.array([dxdt, dydt])
|
|
@@ -261,10 +266,10 @@ class CubicHO(DynSys):
|
|
|
261
266
|
"""Cubic damped harmonic oscillator."""
|
|
262
267
|
|
|
263
268
|
def __init__(self):
|
|
264
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
269
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
265
270
|
|
|
266
271
|
@staticmethod
|
|
267
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
272
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
|
|
268
273
|
dxdt = a * x**3 + b * y**3
|
|
269
274
|
dydt = c * x**3 + d * y**3
|
|
270
275
|
return np.array([dxdt, dydt])
|
|
@@ -279,10 +284,10 @@ class VanDerPol(DynSys):
|
|
|
279
284
|
"""
|
|
280
285
|
|
|
281
286
|
def __init__(self):
|
|
282
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
287
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
283
288
|
|
|
284
289
|
@staticmethod
|
|
285
|
-
def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
|
|
290
|
+
def _rhs(x, x_dot, t: float, mu) -> np.ndarray: # type: ignore # dysts
|
|
286
291
|
dxdt = x_dot
|
|
287
292
|
dx2dt2 = mu * (1 - x**2) * x_dot - x
|
|
288
293
|
return np.array([dxdt, dx2dt2])
|
|
@@ -297,10 +302,10 @@ class Kinematics(DynSys):
|
|
|
297
302
|
"""
|
|
298
303
|
|
|
299
304
|
def __init__(self):
|
|
300
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
305
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
301
306
|
|
|
302
307
|
@staticmethod
|
|
303
|
-
def _rhs(x, v, t: float, a) -> np.ndarray:
|
|
308
|
+
def _rhs(x, v, t: float, a) -> np.ndarray: # type: ignore # dysts
|
|
304
309
|
dxdt = v
|
|
305
310
|
dvdt = a
|
|
306
311
|
return np.array([dxdt, dvdt])
|
sindy_exp/_diffrax_solver.py
CHANGED
|
@@ -6,7 +6,7 @@ import jax.numpy as jnp
|
|
|
6
6
|
import sympy2jax
|
|
7
7
|
from sympy import Expr, Symbol
|
|
8
8
|
|
|
9
|
-
from ._typing import
|
|
9
|
+
from ._typing import SimProbData
|
|
10
10
|
|
|
11
11
|
jax.config.update("jax_enable_x64", True)
|
|
12
12
|
|
|
@@ -22,7 +22,7 @@ def _gen_data_jax(
|
|
|
22
22
|
nonnegative: bool,
|
|
23
23
|
dt: float,
|
|
24
24
|
t_end: float,
|
|
25
|
-
) ->
|
|
25
|
+
) -> SimProbData:
|
|
26
26
|
rhstree = sympy2jax.SymbolicModule(exprs)
|
|
27
27
|
|
|
28
28
|
def ode_sys(t, state, args):
|
|
@@ -71,6 +71,8 @@ def _gen_data_jax(
|
|
|
71
71
|
if noise_abs is None:
|
|
72
72
|
assert noise_rel is not None # force type narrowing
|
|
73
73
|
noise_abs = float(jnp.sqrt(_signal_avg_power(x_train_true)) * noise_rel)
|
|
74
|
+
else:
|
|
75
|
+
noise_rel = noise_abs / float(jnp.sqrt(_signal_avg_power(x_train_true)))
|
|
74
76
|
|
|
75
77
|
x_train = x_train_true + jax.random.normal(key, x_train_true.shape) * noise_abs
|
|
76
78
|
|
|
@@ -78,8 +80,14 @@ def _gen_data_jax(
|
|
|
78
80
|
x_train_true_dot = jnp.array([ode_sys(0, xi, None) for xi in x_train_true])
|
|
79
81
|
|
|
80
82
|
stringy_features = [sym.name for sym in input_features]
|
|
81
|
-
return
|
|
82
|
-
|
|
83
|
+
return SimProbData(
|
|
84
|
+
t_train, # type: ignore # jax->numpy
|
|
85
|
+
x_train, # type: ignore # jax->numpy
|
|
86
|
+
stringy_features,
|
|
87
|
+
x_train_true, # type: ignore # jax->numpy
|
|
88
|
+
x_train_true_dot, # type: ignore # jax->numpy
|
|
89
|
+
noise_abs,
|
|
90
|
+
sol,
|
|
83
91
|
)
|
|
84
92
|
|
|
85
93
|
|
sindy_exp/_odes.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from logging import getLogger
|
|
2
|
-
from typing import Any,
|
|
2
|
+
from typing import Any, Literal, TypeVar, cast, overload
|
|
3
3
|
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
@@ -14,7 +14,7 @@ from ._typing import (
|
|
|
14
14
|
DynamicsTrialData,
|
|
15
15
|
ExperimentResult,
|
|
16
16
|
FullDynamicsTrialData,
|
|
17
|
-
|
|
17
|
+
SimProbData,
|
|
18
18
|
SINDyTrialUpdate,
|
|
19
19
|
_BaseSINDy,
|
|
20
20
|
)
|
|
@@ -41,33 +41,9 @@ DType = TypeVar("DType", bound=np.dtype)
|
|
|
41
41
|
MOD_LOG = getLogger(__name__)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def _add_forcing(
|
|
45
|
-
forcing_func: Callable[[float], np.ndarray[tuple[T], DType]],
|
|
46
|
-
auto_func: Callable[
|
|
47
|
-
[float, np.ndarray[tuple[T], DType]], np.ndarray[tuple[T], DType]
|
|
48
|
-
],
|
|
49
|
-
) -> Callable[[float, np.ndarray], np.ndarray]:
|
|
50
|
-
"""Add a time-dependent forcing term to a rhs func
|
|
51
|
-
|
|
52
|
-
Args:
|
|
53
|
-
forcing_func: The forcing function to add
|
|
54
|
-
auto_func: An existing rhs func for solve_ivp
|
|
55
|
-
|
|
56
|
-
Returns:
|
|
57
|
-
A rhs function for integration
|
|
58
|
-
"""
|
|
59
|
-
|
|
60
|
-
def sum_of_terms(
|
|
61
|
-
t: float, state: np.ndarray[tuple[T], DType]
|
|
62
|
-
) -> np.ndarray[tuple[T], DType]:
|
|
63
|
-
return np.array(forcing_func(t)) + np.array(auto_func(t, state))
|
|
64
|
-
|
|
65
|
-
return sum_of_terms
|
|
66
|
-
|
|
67
|
-
|
|
68
44
|
@overload
|
|
69
45
|
def fit_eval(
|
|
70
|
-
data: tuple[list[
|
|
46
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
71
47
|
model: _BaseSINDy,
|
|
72
48
|
simulations: Literal[False],
|
|
73
49
|
display: bool,
|
|
@@ -76,7 +52,7 @@ def fit_eval(
|
|
|
76
52
|
|
|
77
53
|
@overload
|
|
78
54
|
def fit_eval(
|
|
79
|
-
data: tuple[list[
|
|
55
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
80
56
|
model: _BaseSINDy,
|
|
81
57
|
simulations: Literal[True],
|
|
82
58
|
display: bool,
|
|
@@ -84,7 +60,7 @@ def fit_eval(
|
|
|
84
60
|
|
|
85
61
|
|
|
86
62
|
def fit_eval(
|
|
87
|
-
data: tuple[list[
|
|
63
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
88
64
|
model: Any,
|
|
89
65
|
simulations: bool = True,
|
|
90
66
|
display: bool = True,
|
|
@@ -93,7 +69,7 @@ def fit_eval(
|
|
|
93
69
|
|
|
94
70
|
Args:
|
|
95
71
|
data: Tuple of (trajectories, true_equations), where ``trajectories`` is
|
|
96
|
-
a list of
|
|
72
|
+
a list of SimProbData objects and ``true_equations`` is a list of
|
|
97
73
|
dictionaries mapping SymPy symbols to their true coefficients for
|
|
98
74
|
each state coordinate.
|
|
99
75
|
model: A SINDy-like model implementing the _BaseSINDy protocol.
|
|
@@ -101,6 +77,8 @@ def fit_eval(
|
|
|
101
77
|
display: Whether to generate plots as part of evaluation.
|
|
102
78
|
"""
|
|
103
79
|
model = cast(_BaseSINDy, model)
|
|
80
|
+
for trajectory in data[0]:
|
|
81
|
+
assert trajectory.x_train_true is not None
|
|
104
82
|
trajectories, true_equations = data
|
|
105
83
|
input_features = trajectories[0].input_features
|
|
106
84
|
|
|
@@ -143,7 +121,7 @@ def fit_eval(
|
|
|
143
121
|
sims: list[SINDyTrialUpdate] = []
|
|
144
122
|
integration_metric_list: list[dict[str, float | np.floating]] = []
|
|
145
123
|
for traj in trajectories:
|
|
146
|
-
sim = _simulate_test_data(model, traj.
|
|
124
|
+
sim = _simulate_test_data(model, traj.t_train, traj.x_train_true)
|
|
147
125
|
sims.append(sim)
|
|
148
126
|
integration_metric_list.append(
|
|
149
127
|
integration_metrics(
|
|
@@ -154,9 +132,9 @@ def fit_eval(
|
|
|
154
132
|
)
|
|
155
133
|
)
|
|
156
134
|
|
|
157
|
-
agg_integration_metrics: dict[str, float
|
|
135
|
+
agg_integration_metrics: dict[str, float] = {}
|
|
158
136
|
for key in integration_metric_list[0].keys():
|
|
159
|
-
values = [m[key] for m in integration_metric_list]
|
|
137
|
+
values = cast(list[float], [m[key] for m in integration_metric_list])
|
|
160
138
|
agg_integration_metrics[key] = float(np.mean(values))
|
|
161
139
|
metrics.update(agg_integration_metrics)
|
|
162
140
|
|
sindy_exp/_plotting.py
CHANGED
|
@@ -57,7 +57,7 @@ def plot_coefficients(
|
|
|
57
57
|
feature_names: Sequence[str],
|
|
58
58
|
ax: Axes,
|
|
59
59
|
**heatmap_kws,
|
|
60
|
-
) ->
|
|
60
|
+
) -> Axes:
|
|
61
61
|
"""Plot a set of dynamical system coefficients in a heatmap.
|
|
62
62
|
|
|
63
63
|
Args:
|
|
@@ -162,6 +162,7 @@ def _compare_coefficient_plots_impl(
|
|
|
162
162
|
1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True
|
|
163
163
|
)
|
|
164
164
|
fig.tight_layout()
|
|
165
|
+
assert axs is not None # type narrowing
|
|
165
166
|
|
|
166
167
|
vmax = signed_root(max_val)
|
|
167
168
|
|
|
@@ -275,7 +276,12 @@ def _plot_training_trajectory(
|
|
|
275
276
|
"""
|
|
276
277
|
if x_train.shape[1] == 2:
|
|
277
278
|
ax.plot(
|
|
278
|
-
x_true[:, 0],
|
|
279
|
+
x_true[:, 0],
|
|
280
|
+
x_true[:, 1],
|
|
281
|
+
".",
|
|
282
|
+
label="True",
|
|
283
|
+
color=COLOR.TRUE,
|
|
284
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
279
285
|
)
|
|
280
286
|
ax.plot(
|
|
281
287
|
x_train[:, 0],
|
|
@@ -283,7 +289,7 @@ def _plot_training_trajectory(
|
|
|
283
289
|
".",
|
|
284
290
|
label="Measured",
|
|
285
291
|
color=COLOR.MEAS,
|
|
286
|
-
**PLOT_KWS,
|
|
292
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
287
293
|
)
|
|
288
294
|
if (
|
|
289
295
|
x_smooth is not None
|
|
@@ -295,7 +301,7 @@ def _plot_training_trajectory(
|
|
|
295
301
|
".",
|
|
296
302
|
label="Smoothed",
|
|
297
303
|
color=COLOR.EST,
|
|
298
|
-
**PLOT_KWS,
|
|
304
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
299
305
|
)
|
|
300
306
|
if labels:
|
|
301
307
|
ax.set(xlabel="$x_0$", ylabel="$x_1$")
|
|
@@ -308,7 +314,7 @@ def _plot_training_trajectory(
|
|
|
308
314
|
x_true[:, 2],
|
|
309
315
|
color=COLOR.TRUE,
|
|
310
316
|
label="True values",
|
|
311
|
-
**PLOT_KWS,
|
|
317
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
312
318
|
)
|
|
313
319
|
|
|
314
320
|
ax.plot(
|
sindy_exp/_typing.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
|
+
from collections.abc import Mapping
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from typing import (
|
|
4
5
|
Any,
|
|
5
6
|
Callable,
|
|
6
7
|
Literal,
|
|
7
|
-
NamedTuple,
|
|
8
8
|
Optional,
|
|
9
9
|
Protocol,
|
|
10
10
|
TypedDict,
|
|
@@ -31,9 +31,9 @@ TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
|
|
|
31
31
|
class ExperimentResult[T](TypedDict):
|
|
32
32
|
"""Results from a SINDy ODE experiment."""
|
|
33
33
|
|
|
34
|
-
metrics: float
|
|
34
|
+
metrics: Mapping[str, float | None]
|
|
35
35
|
data: T
|
|
36
|
-
main:
|
|
36
|
+
main: object
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class _BaseSINDy(Protocol):
|
|
@@ -71,23 +71,38 @@ class _BaseSINDy(Protocol):
|
|
|
71
71
|
self, precision: int, fmt: Literal["sympy"]
|
|
72
72
|
) -> list[dict[Expr, float]]: ...
|
|
73
73
|
|
|
74
|
+
@overload
|
|
75
|
+
def print(self, **kwargs) -> None: ...
|
|
76
|
+
|
|
77
|
+
@overload
|
|
74
78
|
def print(self, precision: int, **kwargs) -> None: ...
|
|
75
79
|
|
|
76
80
|
def get_feature_names(self) -> list[str]: ...
|
|
77
81
|
|
|
78
82
|
|
|
79
|
-
|
|
80
|
-
|
|
83
|
+
@dataclass
|
|
84
|
+
class ProbData:
|
|
85
|
+
"""Represents a single trajectory's data.
|
|
81
86
|
|
|
82
|
-
|
|
87
|
+
For measured data, only t_train, x_train, and input_features are required.
|
|
83
88
|
"""
|
|
84
89
|
|
|
85
|
-
dt: float
|
|
86
90
|
t_train: Float1D
|
|
87
91
|
x_train: Float2D
|
|
92
|
+
input_features: list[str]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
|
|
96
|
+
class SimProbData(ProbData):
|
|
97
|
+
"""For simulated data, the noiseless trajectory is known.
|
|
98
|
+
|
|
99
|
+
Optionally includes the integrator solution object for evaluating
|
|
100
|
+
at other points.
|
|
101
|
+
"""
|
|
102
|
+
|
|
88
103
|
x_train_true: Float2D
|
|
89
104
|
x_train_true_dot: Float2D
|
|
90
|
-
|
|
105
|
+
noise_abs: float
|
|
91
106
|
integrator: Optional[Any] = None # diffrax.Solution
|
|
92
107
|
|
|
93
108
|
|
|
@@ -147,7 +162,7 @@ class NestedDict(defaultdict):
|
|
|
147
162
|
|
|
148
163
|
@dataclass
|
|
149
164
|
class DynamicsTrialData:
|
|
150
|
-
trajectories: list[
|
|
165
|
+
trajectories: list[SimProbData]
|
|
151
166
|
true_equations: list[dict[sp.Expr, float]]
|
|
152
167
|
sindy_equations: list[dict[sp.Expr, float]]
|
|
153
168
|
model: _BaseSINDy
|
sindy_exp/_utils.py
CHANGED
|
@@ -149,7 +149,7 @@ def opt_lookup(kind):
|
|
|
149
149
|
def coeff_metrics(
|
|
150
150
|
coeff_est_dicts: list[dict[sp.Expr, float]],
|
|
151
151
|
coeff_true_dicts: list[dict[sp.Expr, float]],
|
|
152
|
-
) -> dict[str, float
|
|
152
|
+
) -> dict[str, float]:
|
|
153
153
|
"""Compute coefficient metrics from aligned coefficient dictionaries.
|
|
154
154
|
|
|
155
155
|
Both arguments are expected to be lists of coefficient dictionaries sharing
|
|
@@ -202,7 +202,7 @@ def coeff_metrics(
|
|
|
202
202
|
coeff_true.flatten(), coefficients.flatten()
|
|
203
203
|
)
|
|
204
204
|
metrics["main"] = metrics["coeff_f1"]
|
|
205
|
-
return metrics
|
|
205
|
+
return {k: float(v) for k, v in metrics.items()}
|
|
206
206
|
|
|
207
207
|
|
|
208
208
|
def pred_metrics(
|
|
@@ -284,7 +284,7 @@ def unionize_coeff_dicts(
|
|
|
284
284
|
|
|
285
285
|
|
|
286
286
|
def _simulate_test_data(
|
|
287
|
-
model: _BaseSINDy,
|
|
287
|
+
model: _BaseSINDy, t_test: Float1D, x_test: Float2D
|
|
288
288
|
) -> SINDyTrialUpdate:
|
|
289
289
|
"""Add simulation data to grid_data
|
|
290
290
|
|
|
@@ -292,7 +292,6 @@ def _simulate_test_data(
|
|
|
292
292
|
Returns:
|
|
293
293
|
Complete GridPointData
|
|
294
294
|
"""
|
|
295
|
-
t_test = cast(Float1D, np.arange(0, len(x_test) * dt, step=dt))
|
|
296
295
|
t_sim = t_test
|
|
297
296
|
try:
|
|
298
297
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sindy-exp
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A basic library for constructing dynamics experiments
|
|
5
5
|
Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -34,7 +34,7 @@ Classifier: Intended Audience :: Science/Research
|
|
|
34
34
|
Classifier: License :: OSI Approved :: MIT License
|
|
35
35
|
Classifier: Natural Language :: English
|
|
36
36
|
Classifier: Operating System :: POSIX :: Linux
|
|
37
|
-
Requires-Python: >=3.
|
|
37
|
+
Requires-Python: >=3.12
|
|
38
38
|
Description-Content-Type: text/markdown
|
|
39
39
|
License-File: LICENSE
|
|
40
40
|
Requires-Dist: matplotlib
|
|
@@ -43,6 +43,7 @@ Requires-Dist: seaborn
|
|
|
43
43
|
Requires-Dist: scipy
|
|
44
44
|
Requires-Dist: sympy
|
|
45
45
|
Requires-Dist: dysts
|
|
46
|
+
Requires-Dist: scikit-learn
|
|
46
47
|
Provides-Extra: jax
|
|
47
48
|
Requires-Dist: jax[cuda12]; extra == "jax"
|
|
48
49
|
Requires-Dist: diffrax; extra == "jax"
|
|
@@ -64,14 +65,20 @@ Requires-Dist: tomli; extra == "dev"
|
|
|
64
65
|
Requires-Dist: pysindy>=2.1.0; extra == "dev"
|
|
65
66
|
Dynamic: license-file
|
|
66
67
|
|
|
67
|
-
#
|
|
68
|
+
# Overview
|
|
68
69
|
|
|
69
|
-
A library for constructing dynamics experiments.
|
|
70
|
-
This includes data generation and
|
|
70
|
+
A library for constructing dynamics experiments from the dynamics models in the `dysts` package.
|
|
71
|
+
This includes data generation and model evaluation.
|
|
72
|
+
The first contribution is a statically typed representation of trajectory data (`ProbData`), which I believe carries the information needed to evaluate a wide variety of dynamics/time-series learning methods.
|
|
73
|
+
The second contribution is the collection of utility functions for designing dynamics learning experiments.
|
|
74
|
+
The third contribution is the collection of such experiments for evaluating dynamics/time-series learning models that meet the `BaseSINDy` API.
|
|
75
|
+
|
|
76
|
+
It aims to (a) be amenable to both `numpy` and `jax` arrays, and (b) be usable by any dynamics/time-series learning model that meets the `BaseSINDy` or scikit-time API.
|
|
77
|
+
Internally, this package is (or will be) used to benchmark pysindy runtime and memory usage and to choose default hyperparameters.
|
|
71
78
|
|
|
72
79
|
## Getting started
|
|
73
80
|
|
|
74
|
-
|
|
81
|
+
Install with `pip install sindy-exp` or `pip install sindy-exp[jax]`.
|
|
75
82
|
|
|
76
83
|
Generate data
|
|
77
84
|
|
|
@@ -86,26 +93,26 @@ Evaluate your SINDy-like model with:
|
|
|
86
93
|
A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
|
|
87
94
|
of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
|
|
88
95
|
|
|
89
|
-
## ODE
|
|
96
|
+
## ODE & Data Model
|
|
97
|
+
|
|
98
|
+
Measured or generated data has the dataclass type `ProbData` or `SimProbData`, respectively,
|
|
99
|
+
to indicate whether it includes ground truth information and a noise level.
|
|
100
|
+
If the data is generated in jax, it will have an integrator that can later be used to evaluate the true data on collocation points.
|
|
90
101
|
|
|
91
102
|
We deal primarily with autonomous ODE systems of the form:
|
|
92
103
|
|
|
93
104
|
dx/dt = sum_i f_i(x)
|
|
94
105
|
|
|
95
|
-
|
|
106
|
+
We represent ODE systems as a list of right-hand side expressions.
|
|
96
107
|
Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
|
|
108
|
+
Thus, the rhs of an ODE is of type: `list[dict[sympy.Expr, float]]`
|
|
97
109
|
|
|
98
110
|
## Other useful imports, compatibility, and extensions
|
|
99
111
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
To integrate your own experiments or data generation in a way that is compatible,
|
|
105
|
-
see the `ProbData` and `DynamicsTrialData` classes.
|
|
106
|
-
For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`,
|
|
107
|
-
`plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
108
|
-
For metrics, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
112
|
+
* The experiments are built to be compatible with the `mitosis` experiment runner. Mitosis is not a dependency, however, so other experiment runners can be used instead.
|
|
113
|
+
* To integrate your own experiments or data generation in a way that is compatible, see the `ProbData`, `SimProbData`, `DynamicsTrialData`, and `FullDynamicsTrialData` classes.
|
|
114
|
+
* For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`, `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
115
|
+
* For evaluation of models, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
109
116
|
|
|
110
117
|

|
|
111
118
|

|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
sindy_exp/__init__.py,sha256=DJuJM0QVe7zmpOHY--7U8m4MCU8XkL_WKGD1H4skSM0,773
|
|
2
|
+
sindy_exp/_data.py,sha256=GwGAnLvEgxRjrwGLTe8RC6GAIOocJtDBey1Dhj8Gyk4,9744
|
|
3
|
+
sindy_exp/_diffrax_solver.py,sha256=HvhmUT9u9A06c6XRSgewdbKl0JEVB312UMmI_Lb4DVo,2634
|
|
4
|
+
sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
|
|
5
|
+
sindy_exp/_odes.py,sha256=E7BdvFoPI7kxNh0Q9WMpE7MsKGbke5zDjX5Ml31amkQ,5927
|
|
6
|
+
sindy_exp/_plotting.py,sha256=mfsDz_693XQS6tMZ8iQluRJqZMcmxyoZeIMP4DfJX2s,17688
|
|
7
|
+
sindy_exp/_typing.py,sha256=hxIxBx52oKMjd1LA8PItwrXlP-PtNGLoZm6oIMZi18I,4536
|
|
8
|
+
sindy_exp/_utils.py,sha256=RRFtk1iFWeo0Rpw1vYhoNSoABLJcRmGo1sIgTZ9lnWM,11299
|
|
9
|
+
sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
|
|
10
|
+
sindy_exp/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
sindy_exp-0.3.0.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
|
|
12
|
+
sindy_exp-0.3.0.dist-info/METADATA,sha256=qTE2umRXW5K9mCtTIt2u-vvtht92rxII90K94ZYaPOE,5669
|
|
13
|
+
sindy_exp-0.3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
14
|
+
sindy_exp-0.3.0.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
|
|
15
|
+
sindy_exp-0.3.0.dist-info/RECORD,,
|
sindy_exp-0.2.2.dist-info/RECORD
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
sindy_exp/__init__.py,sha256=OH4tzHuhWmgXexz-QbXMZtuzkv9kZYyIFBUmDa9n73Q,741
|
|
2
|
-
sindy_exp/_data.py,sha256=_PMoXN4JHhR0bZi2ivD9_EB7rUCJC2xhDuRd4J8edgM,9233
|
|
3
|
-
sindy_exp/_diffrax_solver.py,sha256=c-IjDqaAwLj0rZ4vIm8pMm1U_9K6YiE1XCaL72NAVpI,2362
|
|
4
|
-
sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
|
|
5
|
-
sindy_exp/_odes.py,sha256=cMUNvS6TL4_nRLrjK2MBoLzgpM__gu4NwzY2n4gfQn8,6513
|
|
6
|
-
sindy_exp/_plotting.py,sha256=dpcqAXKzb0mSVl0p2WyMadVoGhQi43oL6ZqsbuheEuk,17470
|
|
7
|
-
sindy_exp/_typing.py,sha256=_KKtGcXOZmlR3Fg77G6TmlH5eKDwnoHjusex2Gxlf_4,4196
|
|
8
|
-
sindy_exp/_utils.py,sha256=zR9Npjl8PeeSp6MHHD5lL3q37gDXIOLTEfDqGdY_2fM,11341
|
|
9
|
-
sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
|
|
10
|
-
sindy_exp/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
sindy_exp-0.2.2.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
|
|
12
|
-
sindy_exp-0.2.2.dist-info/METADATA,sha256=JU-BigVCn7w2WhWDRfA3d32RyEz7Yb4EKjxc1ng-a4o,4514
|
|
13
|
-
sindy_exp-0.2.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
14
|
-
sindy_exp-0.2.2.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
|
|
15
|
-
sindy_exp-0.2.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|