sindy-exp 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sindy_exp/__init__.py +28 -0
- sindy_exp/_data.py +202 -0
- sindy_exp/_diffrax_solver.py +104 -0
- sindy_exp/_dysts_to_sympy.py +452 -0
- sindy_exp/_odes.py +287 -0
- sindy_exp/_plotting.py +544 -0
- sindy_exp/_typing.py +158 -0
- sindy_exp/_utils.py +381 -0
- sindy_exp/addl_attractors.json +91 -0
- sindy_exp-0.2.0.dist-info/METADATA +111 -0
- sindy_exp-0.2.0.dist-info/RECORD +14 -0
- sindy_exp-0.2.0.dist-info/WHEEL +5 -0
- sindy_exp-0.2.0.dist-info/licenses/LICENSE +21 -0
- sindy_exp-0.2.0.dist-info/top_level.txt +1 -0
sindy_exp/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Public API for the sindy_exp package.

Re-exports data generation, fitting/evaluation, plotting, typing, and
metric helpers from the private submodules; ``__all__`` declares the
supported surface explicitly.
"""

from ._data import ODE_CLASSES, gen_data
from ._odes import fit_eval, plot_ode_panel
from ._plotting import (
    COLOR,
    compare_coefficient_plots_from_dicts,
    plot_coefficients,
    plot_test_trajectory,
    plot_training_data,
)
from ._typing import DynamicsTrialData, ProbData
from ._utils import coeff_metrics, integration_metrics, pred_metrics

__all__ = [
    "gen_data",
    "fit_eval",
    "ProbData",
    "DynamicsTrialData",
    "coeff_metrics",
    "pred_metrics",
    "integration_metrics",
    "ODE_CLASSES",
    "plot_ode_panel",
    "plot_coefficients",
    "compare_coefficient_plots_from_dicts",
    "plot_test_trajectory",
    "plot_training_data",
    "COLOR",
]
|
sindy_exp/_data.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from logging import getLogger
from typing import Any, Callable, Optional, cast

import dysts.flows
import dysts.systems
import numpy as np
import scipy

from ._dysts_to_sympy import dynsys_to_sympy
from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
from ._plotting import plot_training_data
from ._typing import Float1D, ProbData
from ._utils import _sympy_expr_to_feat_coeff

try:
    import jax

    from ._diffrax_solver import _gen_data_jax
except ImportError:
    # BUG FIX: the original re-raised here, which made jax/diffrax hard
    # requirements and turned gen_data()'s descriptive ImportError for the
    # "jax" namespace into dead code.  These dependencies are optional:
    # swallow the error and let gen_data() complain only when jax data
    # generation is actually requested.
    pass

# Very tight tolerances so reference trajectories are effectively exact.
INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
MOD_LOG = getLogger(__name__)

# Lowercase system name -> dynamical-system class, combining every dysts
# attractor with the locally-defined toy systems (local names win on clash).
ODE_CLASSES = {
    klass.lower(): getattr(dysts.flows, klass)
    for klass in dysts.systems.get_attractor_list()
}
ODE_CLASSES.update(
    {
        "lotkavolterra": LotkaVolterra,
        "sho": SHO,
        "cubicho": CubicHO,
        "hopf": Hopf,
        "vanderpol": VanDerPol,
        "kinematics": Kinematics,
    }
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def gen_data(
    system: str,
    seed: int,
    n_trajectories: int = 1,
    ic_stdev: float = 3,
    noise_abs: Optional[float] = None,
    noise_rel: Optional[float] = None,
    dt: float = 0.01,
    t_end: float = 10,
    display: bool = False,
    array_namespace: str = "numpy",
) -> dict[str, Any]:
    """Generate random training and test data

    An Experiment step according to the mitosis experiment runner.
    Note that test data has no noise.

    Arguments:
        system: the system to integrate
        seed (int): the random seed for number generation
        n_trajectories (int): number of trajectories of training data
        ic_stdev (float): standard deviation for generating initial conditions
        noise_abs (float): measurement noise standard deviation.
            Defaults to .1 if noise_rel is None.
        noise_rel (float): measurement noise-to-signal power ratio.
            Either noise_abs or noise_rel must be None. Defaults to
            None.
        dt: time step for sample
        t_end: end time of simulation
        display: Whether to display graphics of generated data.
        array_namespace: "numpy" (scipy solver) or "jax" (diffrax solver).

    Returns:
        dictionary of data and descriptive information

    Raises:
        ValueError: unknown system, unknown array_namespace, or both noise
            arguments specified.
        ImportError: "jax" requested but diffrax/sympy2jax are missing.
    """
    try:
        dyst_sys = ODE_CLASSES[system.lower()]()
    except KeyError as e:
        raise ValueError(
            f"Unknown system {system}. Check {__name__}.ODE_CLASSES"
        ) from e
    input_features, sp_expr, sp_lambda = dynsys_to_sympy(dyst_sys)
    coeff_true = _sympy_expr_to_feat_coeff(sp_expr)
    rhsfunc = lambda t, X: dyst_sys.rhs(X, t)  # noqa: E731
    # Use the system's reference initial condition when it has one, else the
    # origin.  A missing attribute normally raises AttributeError; keep
    # KeyError too in case dysts surfaces metadata through a mapping.
    try:
        x0_center = dyst_sys.ic
    except (KeyError, AttributeError):
        x0_center = np.zeros((len(input_features)), dtype=np.float64)
    # getattr with a default never raises, so no try/except is needed here
    # (the original wrapped this in a dead `except KeyError`).
    nonnegative = getattr(dyst_sys, "nonnegative", False)
    if noise_abs is not None and noise_rel is not None:
        raise ValueError("Cannot specify both noise_abs and noise_rel")
    elif noise_abs is None and noise_rel is None:
        noise_abs = 0.1

    # Lazy %-formatting; also fixes the original's stray "f" ("of f{system}").
    MOD_LOG.info("Generating %s trajectories of %s", n_trajectories, system)
    prob_data_list: list[ProbData] = []
    if array_namespace == "numpy":
        feature_names = [feat.name for feat in input_features]
        for _ in range(n_trajectories):
            seed += 1
            prob = _gen_data(
                rhsfunc,
                feature_names,
                seed,
                x0_center=x0_center,
                nonnegative=nonnegative,
                ic_stdev=ic_stdev,
                noise_abs=noise_abs,
                noise_rel=noise_rel,
                dt=dt,
                t_end=t_end,
            )
            prob_data_list.append(prob)
    elif array_namespace == "jax":
        # _gen_data_jax is only present when the optional jax stack imported.
        if "_gen_data_jax" not in globals():
            raise ImportError(
                "jax data generation requested but diffrax or sympy2jax not"
                " installed"
            )
        this_seed = jax.random.PRNGKey(seed)
        for _ in range(n_trajectories):
            this_seed, _ = jax.random.split(this_seed)
            prob = _gen_data_jax(
                sp_expr,
                input_features,
                this_seed,
                x0_center=x0_center,
                nonnegative=nonnegative,
                ic_stdev=ic_stdev,
                noise_abs=noise_abs,
                noise_rel=noise_rel,
                dt=dt,
                t_end=t_end,
            )
            prob_data_list.append(prob)
    else:
        raise ValueError(
            f"Unknown array_namespace {array_namespace}. Must be 'numpy' or 'jax'"
        )
    if display and prob_data_list:
        sample = prob_data_list[0]
        figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
        figs[0].suptitle("Sample Trajectory")
    return {
        "data": {"trajectories": prob_data_list, "coeff_true": coeff_true},
        # BUG FIX: describe by system name; the original interpolated the
        # repr of the rhsfunc lambda.
        "main": f"{n_trajectories} trajectories of {system}",
        "metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
    }
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _gen_data(
    rhs_func: Callable,
    input_features: list[str],
    seed: Optional[int],
    x0_center: Float1D,
    ic_stdev: float,
    noise_abs: Optional[float],
    noise_rel: Optional[float],
    nonnegative: bool,
    dt: float,
    t_end: float,
) -> ProbData:
    """Integrate one noisy training trajectory with scipy's stiff solver.

    Draws a random initial condition around ``x0_center``, integrates
    ``rhs_func`` on a uniform grid, and adds Gaussian measurement noise.
    """
    rng = np.random.default_rng(seed)
    t_train = np.arange(0, t_end, dt)
    t_train_span = (t_train[0], t_train[-1])
    n_coord = len(input_features)
    # Initial condition: gamma-distributed (moments matched to the Gaussian
    # branch) when the state must stay nonnegative, otherwise Gaussian.
    if nonnegative:
        shape = ((x0_center + 1) / ic_stdev) ** 2
        scale = ic_stdev**2 / (x0_center + 1)
        x0 = np.array([rng.gamma(k, theta) for k, theta in zip(shape, scale)])
    else:
        x0 = x0_center + ic_stdev * rng.standard_normal(n_coord)
    solution = scipy.integrate.solve_ivp(
        rhs_func, t_train_span, x0, t_eval=t_train, **INTEGRATOR_KEYWORDS
    )
    x_train = solution.y.T

    x_train_true = np.copy(x_train)
    # Exact derivatives of the clean states, evaluating the RHS at t=0 for
    # every sample (matches the original behavior).
    x_train_true_dot = np.array([rhs_func(0, xi) for xi in x_train_true])
    if noise_rel is not None:
        # Convert the noise-to-signal power ratio into an absolute stdev.
        noise_abs = np.sqrt(noise_rel * _signal_avg_power(x_train))
    noise = rng.standard_normal(x_train.shape)
    x_train = x_train + cast(float, noise_abs) * noise

    return ProbData(
        dt, t_train, x_train, x_train_true, x_train_true_dot, input_features
    )
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _max_amplitude(signal: np.ndarray, axis: int) -> float:
|
|
196
|
+
return np.abs(scipy.fft.rfft(signal, axis=axis)[1:]).max() / np.sqrt(
|
|
197
|
+
signal.shape[axis]
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _signal_avg_power(signal: np.ndarray) -> float:
|
|
202
|
+
return np.square(signal).mean()
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Optional

import diffrax
import jax
import jax.numpy as jnp
import sympy2jax
from sympy import Expr, Symbol

from ._typing import ProbData

# Enable double precision; jax defaults to float32, which is too coarse for
# the reference trajectories generated here.
jax.config.update("jax_enable_x64", True)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _gen_data_jax(
    exprs: list[Expr],
    input_features: list[Symbol],
    seed: jax.Array,
    x0_center: jax.Array,
    ic_stdev: float,
    noise_abs: Optional[float],
    noise_rel: Optional[float],
    nonnegative: bool,
    dt: float,
    t_end: float,
) -> ProbData:
    """Integrate one noisy training trajectory with diffrax.

    Mirrors ``sindy_exp._data._gen_data`` but takes sympy expressions and a
    jax PRNG key instead of a Python RHS callable and an integer seed.

    Returns:
        ProbData with time grid, noisy and clean states, true derivatives,
        feature-name strings, and the diffrax solution object.
    """
    rhstree = sympy2jax.SymbolicModule(exprs)

    def ode_sys(t, state, args):
        # Evaluate the symbolic RHS by binding each state coordinate to its
        # sympy symbol's name.
        return jnp.asarray(
            rhstree(
                **{
                    str(x_sym): state_i
                    for x_sym, state_i in zip(input_features, state, strict=True)
                }
            )
        )

    term = diffrax.ODETerm(ode_sys)
    solver = diffrax.Tsit5()
    save_at = diffrax.SaveAt(ts=jnp.arange(0, t_end, dt), dense=True)

    # Random initialization
    key, subkey = jax.random.split(seed)
    t_train = jnp.arange(0, t_end, dt)
    if nonnegative:
        # Gamma draw with mean/variance matched to the Gaussian branch below.
        shape = ((x0_center + 1) / ic_stdev) ** 2
        scale = ic_stdev**2 / (x0_center + 1)
        # BUG FIX: the original passed a generator expression to jnp.array
        # (a TypeError at runtime) and reused one sub-key for every draw.
        # jax.random.gamma broadcasts over an array of shape parameters,
        # yielding one independent draw per coordinate from a single key.
        x0 = jax.random.gamma(subkey, shape) * scale
    else:
        x0 = ic_stdev * jax.random.normal(subkey, (len(input_features),)) + x0_center
    key, subkey = jax.random.split(key)

    # IVPs
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=t_end,
        dt0=dt,  # Initial step size
        y0=x0,
        args=(),
        saveat=save_at,
        max_steps=int(10 * (t_end - 0) / dt),
    )
    x_train_true: jax.Array = sol.ys  # type: ignore

    # Measurement noise
    if noise_abs is None:
        assert noise_rel is not None  # force type narrowing
        noise_abs = float(jnp.sqrt(_signal_avg_power(x_train_true)) * noise_rel)

    x_train = x_train_true + jax.random.normal(key, x_train_true.shape) * noise_abs

    # True Derivatives
    x_train_true_dot = jnp.array([ode_sys(0, xi, None) for xi in x_train_true])

    stringy_features = [sym.name for sym in input_features]
    return ProbData(
        dt, t_train, x_train, x_train_true, x_train_true_dot, stringy_features, sol
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _signal_avg_power(signal: jax.Array) -> jax.Array:
|
|
87
|
+
return jnp.square(signal).mean()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
## % # noqa:E266
if __name__ == "__main__":
    # Debug example
    from sindy_exp._data import gen_data

    data_dict = gen_data(
        "valliselnino",
        seed=50,
        n_trajectories=1,
        ic_stdev=3,
        noise_rel=0.1,
        display=True,
        array_namespace="jax",
    )
    # BUG FIX: gen_data returns only "data"/"main"/"metrics" keys, so the
    # original data_dict["input_features"] raised KeyError.  The features
    # live on each ProbData (its sixth positional field).
    # NOTE(review): assumes ProbData names that field `input_features` —
    # confirm against sindy_exp._typing.
    print(data_dict["data"]["trajectories"][0].input_features)