sindy-exp 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sindy_exp/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from ._data import ODE_CLASSES, gen_data
2
+ from ._odes import fit_eval, plot_ode_panel
3
+ from ._plotting import (
4
+ COLOR,
5
+ compare_coefficient_plots_from_dicts,
6
+ plot_coefficients,
7
+ plot_test_trajectory,
8
+ plot_training_data,
9
+ )
10
+ from ._typing import DynamicsTrialData, ProbData
11
+ from ._utils import coeff_metrics, integration_metrics, pred_metrics
12
+
13
# Public API of the package: every name below is re-exported from one of the
# private submodules imported above.
__all__ = [
    "gen_data", "fit_eval",
    "ProbData", "DynamicsTrialData",
    "coeff_metrics", "pred_metrics", "integration_metrics",
    "ODE_CLASSES",
    "plot_ode_panel", "plot_coefficients", "compare_coefficient_plots_from_dicts",
    "plot_test_trajectory", "plot_training_data",
    "COLOR",
]
sindy_exp/_data.py ADDED
@@ -0,0 +1,202 @@
1
+ from logging import getLogger
2
+ from typing import Any, Callable, Optional, cast
3
+
4
+ import dysts.flows
5
+ import dysts.systems
6
+ import numpy as np
7
+ import scipy
8
+
9
+ from ._dysts_to_sympy import dynsys_to_sympy
10
+ from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
11
+ from ._plotting import plot_training_data
12
+ from ._typing import Float1D, ProbData
13
+ from ._utils import _sympy_expr_to_feat_coeff
14
+
15
try:
    import jax

    from ._diffrax_solver import _gen_data_jax
except ImportError:
    # The jax backend (jax, diffrax, sympy2jax) is optional. When it is
    # missing, neither `jax` nor `_gen_data_jax` is bound in this module, and
    # gen_data(array_namespace="jax") raises a descriptive ImportError after
    # checking for `_gen_data_jax`. Re-raising here would make the whole
    # package unimportable without jax and leave that check unreachable.
    pass
21
+
22
# Keyword arguments forwarded to scipy.integrate.solve_ivp by _gen_data.
INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
MOD_LOG = getLogger(__name__)

# Registry of simulatable systems keyed by lower-cased class name: every
# attractor listed by dysts, plus this package's own ODE classes (which
# overwrite any dysts entry that has the same key).
ODE_CLASSES = {
    attractor_name.lower(): getattr(dysts.flows, attractor_name)
    for attractor_name in dysts.systems.get_attractor_list()
}
ODE_CLASSES["lotkavolterra"] = LotkaVolterra
ODE_CLASSES["sho"] = SHO
ODE_CLASSES["cubicho"] = CubicHO
ODE_CLASSES["hopf"] = Hopf
ODE_CLASSES["vanderpol"] = VanDerPol
ODE_CLASSES["kinematics"] = Kinematics
39
+
40
+
41
def gen_data(
    system: str,
    seed: int,
    n_trajectories: int = 1,
    ic_stdev: float = 3,
    noise_abs: Optional[float] = None,
    noise_rel: Optional[float] = None,
    dt: float = 0.01,
    t_end: float = 10,
    display: bool = False,
    array_namespace: str = "numpy",
) -> dict[str, Any]:
    """Generate random training and test data

    An Experiment step according to the mitosis experiment runner.
    Note that test data has no noise.

    Arguments:
        system: the system to integrate; case-insensitive key into
            ``ODE_CLASSES``.
        seed (int): the random seed for number generation
        n_trajectories (int): number of trajectories of training data
        ic_stdev (float): standard deviation for generating initial conditions
        noise_abs (float): measurement noise standard deviation.
            Defaults to .1 if noise_rel is None.
        noise_rel (float): measurement noise-to-signal power ratio.
            Either noise_abs or noise_rel must be None. Defaults to
            None.
        dt: time step for sample
        t_end: end time of simulation
        display: Whether to display graphics of generated data.
        array_namespace: integration backend, either "numpy" (scipy
            ``solve_ivp``) or "jax" (diffrax).

    Returns:
        dictionary with keys "data" (the list of trajectories and the true
        coefficients), "main" (a short description) and "metrics" (the
        relative and absolute noise levels used).

    Raises:
        ValueError: if ``system`` or ``array_namespace`` is unknown, or if
            both ``noise_abs`` and ``noise_rel`` are given.
        ImportError: if the jax backend is requested but its dependencies
            could not be imported.
    """
    # Look up the class on its own so that a KeyError raised *inside* the
    # system's constructor is not misreported as an unknown system name.
    try:
        system_cls = ODE_CLASSES[system.lower()]
    except KeyError as e:
        raise ValueError(
            f"Unknown system {system}. Check {__name__}.ODE_CLASSES"
        ) from e
    # Validate cheap arguments before the comparatively expensive symbolic
    # conversion below.
    if noise_abs is not None and noise_rel is not None:
        raise ValueError("Cannot specify both noise_abs and noise_rel")
    if noise_abs is None and noise_rel is None:
        noise_abs = 0.1
    if array_namespace not in ("numpy", "jax"):
        raise ValueError(
            f"Unknown array_namespace {array_namespace}. Must be 'numpy' or 'jax'"
        )

    dyst_sys = system_cls()
    input_features, sp_expr, _ = dynsys_to_sympy(dyst_sys)
    coeff_true = _sympy_expr_to_feat_coeff(sp_expr)
    # solve_ivp calls f(t, y), whereas the system's rhs is called as rhs(X, t).
    rhsfunc = lambda t, X: dyst_sys.rhs(X, t)  # noqa: E731
    try:
        x0_center = dyst_sys.ic
    except (AttributeError, KeyError):
        # No initial-condition metadata available: center ICs at the origin.
        x0_center = np.zeros(len(input_features), dtype=np.float64)
    try:
        nonnegative = getattr(dyst_sys, "nonnegative", False)
    except KeyError:
        nonnegative = False

    MOD_LOG.info("Generating %s trajectories of %s", n_trajectories, system)
    prob_data_list: list[ProbData] = []
    if array_namespace == "numpy":
        feature_names = [feat.name for feat in input_features]
        for _ in range(n_trajectories):
            seed += 1
            prob_data_list.append(
                _gen_data(
                    rhsfunc,
                    feature_names,
                    seed,
                    x0_center=x0_center,
                    nonnegative=nonnegative,
                    ic_stdev=ic_stdev,
                    noise_abs=noise_abs,
                    noise_rel=noise_rel,
                    dt=dt,
                    t_end=t_end,
                )
            )
    else:  # array_namespace == "jax" (validated above)
        # _gen_data_jax is only bound when the optional jax imports succeeded.
        if "_gen_data_jax" not in globals():
            raise ImportError(
                "jax data generation requested but diffrax or sympy2jax not"
                " installed"
            )
        this_seed = jax.random.PRNGKey(seed)
        for _ in range(n_trajectories):
            this_seed, _ = jax.random.split(this_seed)
            prob_data_list.append(
                _gen_data_jax(
                    sp_expr,
                    input_features,
                    this_seed,
                    x0_center=x0_center,
                    nonnegative=nonnegative,
                    ic_stdev=ic_stdev,
                    noise_abs=noise_abs,
                    noise_rel=noise_rel,
                    dt=dt,
                    t_end=t_end,
                )
            )

    if display and prob_data_list:
        sample = prob_data_list[0]
        figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
        figs[0].suptitle("Sample Trajectory")
    return {
        "data": {"trajectories": prob_data_list, "coeff_true": coeff_true},
        # Describe the run by system name (formatting the rhs lambda would
        # only produce an opaque "<function ...<lambda> at 0x...>" repr).
        "main": f"{n_trajectories} trajectories of {system}",
        "metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
    }
153
+
154
+
155
def _gen_data(
    rhs_func: Callable,
    input_features: list[str],
    seed: Optional[int],
    x0_center: Float1D,
    ic_stdev: float,
    noise_abs: Optional[float],
    noise_rel: Optional[float],
    nonnegative: bool,
    dt: float,
    t_end: float,
) -> ProbData:
    """Simulate one trajectory with scipy's ``solve_ivp`` and add noise.

    The initial condition is drawn around ``x0_center`` with spread
    ``ic_stdev``: per-component gamma draws when ``nonnegative``, Gaussian
    otherwise. The trajectory is sampled at ``np.arange(0, t_end, dt)``.
    Gaussian measurement noise with standard deviation ``noise_abs`` is then
    added; when ``noise_rel`` is given, the standard deviation is instead
    ``sqrt(noise_rel * mean(x**2))`` of the clean trajectory.

    Returns:
        ProbData(dt, t_train, noisy trajectory, clean trajectory,
        rhs evaluated along the clean trajectory, input_features)
    """
    generator = np.random.default_rng(seed)
    t_train = np.arange(0, t_end, dt)
    n_dims = len(input_features)

    if not nonnegative:
        x0 = ic_stdev * generator.standard_normal(n_dims) + x0_center
    else:
        # Gamma(k, theta) per component with mean k*theta = x0_center + 1
        # and variance k*theta**2 = ic_stdev**2.
        gamma_k = ((x0_center + 1) / ic_stdev) ** 2
        gamma_theta = ic_stdev**2 / (x0_center + 1)
        x0 = np.array(
            [generator.gamma(k, theta) for k, theta in zip(gamma_k, gamma_theta)]
        )

    solution = scipy.integrate.solve_ivp(
        rhs_func,
        (t_train[0], t_train[-1]),
        x0,
        t_eval=t_train,
        **INTEGRATOR_KEYWORDS,
    )
    x_clean = solution.y.T
    # rhs is evaluated with t=0 at every sample point.
    x_dot = np.array([rhs_func(0, state) for state in x_clean])

    if noise_rel is not None:
        noise_abs = np.sqrt(_signal_avg_power(x_clean) * noise_rel)
    x_noisy = x_clean + cast(float, noise_abs) * generator.standard_normal(
        x_clean.shape
    )

    return ProbData(dt, t_train, x_noisy, np.copy(x_clean), x_dot, input_features)
193
+
194
+
195
+ def _max_amplitude(signal: np.ndarray, axis: int) -> float:
196
+ return np.abs(scipy.fft.rfft(signal, axis=axis)[1:]).max() / np.sqrt(
197
+ signal.shape[axis]
198
+ )
199
+
200
+
201
+ def _signal_avg_power(signal: np.ndarray) -> float:
202
+ return np.square(signal).mean()
@@ -0,0 +1,104 @@
1
+ from typing import Optional
2
+
3
+ import diffrax
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import sympy2jax
7
+ from sympy import Expr, Symbol
8
+
9
+ from ._typing import ProbData
10
+
11
jax.config.update("jax_enable_x64", True)


def _gen_data_jax(
    exprs: list[Expr],
    input_features: list[Symbol],
    seed: jax.Array,
    x0_center: jax.Array,
    ic_stdev: float,
    noise_abs: Optional[float],
    noise_rel: Optional[float],
    nonnegative: bool,
    dt: float,
    t_end: float,
) -> ProbData:
    """Simulate one trajectory with diffrax (Tsit5) and add measurement noise.

    Arguments:
        exprs: sympy expressions for the right-hand side, evaluated in order
            by a ``sympy2jax.SymbolicModule``.
        input_features: sympy symbols for the state variables, in the order
            of the state vector.
        seed: JAX PRNG key used for the initial condition and the noise.
        x0_center: center of the initial-condition distribution.
        ic_stdev: spread of the initial-condition distribution.
        noise_abs: measurement-noise standard deviation; if None it is
            derived from ``noise_rel``.
        noise_rel: noise-to-signal *power* ratio, i.e. the noise standard
            deviation is ``sqrt(noise_rel * mean(x**2))`` of the clean
            trajectory.
        nonnegative: draw the initial condition from independent
            per-component gamma distributions (mean ``x0_center + 1``,
            standard deviation ``ic_stdev``) instead of a Gaussian.
        dt: sampling interval (also the solver's initial step size).
        t_end: samples are taken at ``jnp.arange(0, t_end, dt)``.

    Returns:
        ProbData with the noisy and clean trajectories, the right-hand side
        evaluated along the clean trajectory, the feature names and the
        diffrax solution object.

    Raises:
        ValueError: if both ``noise_abs`` and ``noise_rel`` are None.
    """
    rhstree = sympy2jax.SymbolicModule(exprs)
    # Keyword names the symbolic module is called with, one per state entry.
    kwarg_names = [str(x_sym) for x_sym in input_features]

    def ode_sys(t, state, args):
        # t and args are ignored: the rhs depends on the state only.
        return jnp.asarray(
            rhstree(
                **{
                    name: state_i
                    for name, state_i in zip(kwarg_names, state, strict=True)
                }
            )
        )

    t_train = jnp.arange(0, t_end, dt)
    term = diffrax.ODETerm(ode_sys)
    solver = diffrax.Tsit5()
    save_at = diffrax.SaveAt(ts=t_train, dense=True)

    # Random initialization
    key, subkey = jax.random.split(seed)
    if nonnegative:
        x0_center = jnp.asarray(x0_center)
        shape = ((x0_center + 1) / ic_stdev) ** 2
        scale = ic_stdev**2 / (x0_center + 1)
        # jax.random.gamma vectorises over the array of shape parameters,
        # giving one independent draw per component. (jnp.array cannot
        # consume a generator, and reusing one key per component would
        # correlate the draws.)
        x0 = jax.random.gamma(subkey, shape) * scale
    else:
        x0 = ic_stdev * jax.random.normal(subkey, (len(input_features),)) + x0_center
    key, _ = jax.random.split(key)

    # IVPs
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=t_end,
        dt0=dt,  # Initial step size
        y0=x0,
        args=(),
        saveat=save_at,
        max_steps=int(10 * t_end / dt),
    )
    x_train_true: jax.Array = sol.ys  # type: ignore

    # Measurement noise
    if noise_abs is None:
        if noise_rel is None:
            raise ValueError("Either noise_abs or noise_rel must be given")
        # noise_rel is a power ratio (as in the numpy backend): the noise
        # variance is noise_rel times the mean signal power.
        noise_abs = float(jnp.sqrt(_signal_avg_power(x_train_true) * noise_rel))

    x_train = x_train_true + jax.random.normal(key, x_train_true.shape) * noise_abs

    # True Derivatives
    x_train_true_dot = jnp.array([ode_sys(0, xi, None) for xi in x_train_true])

    stringy_features = [sym.name for sym in input_features]
    return ProbData(
        dt, t_train, x_train, x_train_true, x_train_true_dot, stringy_features, sol
    )
84
+
85
+
86
+ def _signal_avg_power(signal: jax.Array) -> jax.Array:
87
+ return jnp.square(signal).mean()
88
+
89
+
90
+ ## % # noqa:E266
91
if __name__ == "__main__":
    # Debug example
    from sindy_exp._data import gen_data

    data_dict = gen_data(
        "valliselnino",
        seed=50,
        n_trajectories=1,
        ic_stdev=3,
        noise_rel=0.1,
        display=True,
        array_namespace="jax",
    )
    # gen_data returns only the keys "data", "main" and "metrics"; there is
    # no top-level "input_features" entry, so indexing it raised KeyError.
    print(data_dict["main"])
    print(data_dict["data"]["coeff_true"])