sindy-exp 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sindy_exp/__init__.py +2 -1
- sindy_exp/_data.py +120 -16
- sindy_exp/_diffrax_solver.py +0 -17
- sindy_exp/_odes.py +28 -112
- sindy_exp/_typing.py +9 -0
- sindy_exp/_utils.py +10 -51
- sindy_exp/py.typed +0 -0
- {sindy_exp-0.2.0.dist-info → sindy_exp-0.2.2.dist-info}/METADATA +1 -1
- sindy_exp-0.2.2.dist-info/RECORD +15 -0
- sindy_exp-0.2.0.dist-info/RECORD +0 -14
- {sindy_exp-0.2.0.dist-info → sindy_exp-0.2.2.dist-info}/WHEEL +0 -0
- {sindy_exp-0.2.0.dist-info → sindy_exp-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {sindy_exp-0.2.0.dist-info → sindy_exp-0.2.2.dist-info}/top_level.txt +0 -0
sindy_exp/__init__.py
CHANGED
|
@@ -7,7 +7,7 @@ from ._plotting import (
|
|
|
7
7
|
plot_test_trajectory,
|
|
8
8
|
plot_training_data,
|
|
9
9
|
)
|
|
10
|
-
from ._typing import DynamicsTrialData, ProbData
|
|
10
|
+
from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData
|
|
11
11
|
from ._utils import coeff_metrics, integration_metrics, pred_metrics
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
@@ -15,6 +15,7 @@ __all__ = [
|
|
|
15
15
|
"fit_eval",
|
|
16
16
|
"ProbData",
|
|
17
17
|
"DynamicsTrialData",
|
|
18
|
+
"FullDynamicsTrialData",
|
|
18
19
|
"coeff_metrics",
|
|
19
20
|
"pred_metrics",
|
|
20
21
|
"integration_metrics",
|
sindy_exp/_data.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
|
1
|
+
from importlib import resources
|
|
1
2
|
from logging import getLogger
|
|
2
|
-
from typing import
|
|
3
|
+
from typing import Callable, Optional, cast
|
|
3
4
|
|
|
4
5
|
import dysts.flows
|
|
5
6
|
import dysts.systems
|
|
6
7
|
import numpy as np
|
|
7
8
|
import scipy
|
|
9
|
+
import sympy as sp
|
|
10
|
+
from dysts.base import DynSys
|
|
8
11
|
|
|
9
12
|
from ._dysts_to_sympy import dynsys_to_sympy
|
|
10
|
-
from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
|
|
11
13
|
from ._plotting import plot_training_data
|
|
12
|
-
from ._typing import Float1D, ProbData
|
|
14
|
+
from ._typing import ExperimentResult, Float1D, ProbData
|
|
13
15
|
from ._utils import _sympy_expr_to_feat_coeff
|
|
14
16
|
|
|
15
17
|
try:
|
|
@@ -17,25 +19,16 @@ try:
|
|
|
17
19
|
|
|
18
20
|
from ._diffrax_solver import _gen_data_jax
|
|
19
21
|
except ImportError:
|
|
20
|
-
|
|
22
|
+
pass
|
|
21
23
|
|
|
22
24
|
INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
|
|
23
25
|
MOD_LOG = getLogger(__name__)
|
|
26
|
+
LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
|
|
24
27
|
|
|
25
28
|
ODE_CLASSES = {
|
|
26
29
|
klass.lower(): getattr(dysts.flows, klass)
|
|
27
30
|
for klass in dysts.systems.get_attractor_list()
|
|
28
31
|
}
|
|
29
|
-
ODE_CLASSES.update(
|
|
30
|
-
{
|
|
31
|
-
"lotkavolterra": LotkaVolterra,
|
|
32
|
-
"sho": SHO,
|
|
33
|
-
"cubicho": CubicHO,
|
|
34
|
-
"hopf": Hopf,
|
|
35
|
-
"vanderpol": VanDerPol,
|
|
36
|
-
"kinematics": Kinematics,
|
|
37
|
-
}
|
|
38
|
-
)
|
|
39
32
|
|
|
40
33
|
|
|
41
34
|
def gen_data(
|
|
@@ -49,7 +42,7 @@ def gen_data(
|
|
|
49
42
|
t_end: float = 10,
|
|
50
43
|
display: bool = False,
|
|
51
44
|
array_namespace: str = "numpy",
|
|
52
|
-
) -> dict[
|
|
45
|
+
) -> ExperimentResult[tuple[list[ProbData], list[dict[sp.Expr, float]]]]:
|
|
53
46
|
"""Generate random training and test data
|
|
54
47
|
|
|
55
48
|
An Experiment step according to the mitosis experiment runner.
|
|
@@ -146,7 +139,7 @@ def gen_data(
|
|
|
146
139
|
figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
|
|
147
140
|
figs[0].suptitle("Sample Trajectory")
|
|
148
141
|
return {
|
|
149
|
-
"data":
|
|
142
|
+
"data": (prob_data_list, coeff_true),
|
|
150
143
|
"main": f"{n_trajectories} trajectories of {rhsfunc}",
|
|
151
144
|
"metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
|
|
152
145
|
}
|
|
@@ -200,3 +193,114 @@ def _max_amplitude(signal: np.ndarray, axis: int) -> float:
|
|
|
200
193
|
|
|
201
194
|
def _signal_avg_power(signal: np.ndarray) -> float:
|
|
202
195
|
return np.square(signal).mean()
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _register_dyst(klass: type[DynSys]) -> type[DynSys]:
|
|
199
|
+
"""Register a custom dysts DynSys class for use in sindy_exp data generation."""
|
|
200
|
+
ODE_CLASSES[klass.__name__.lower()] = klass
|
|
201
|
+
return klass
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@_register_dyst
|
|
205
|
+
class LotkaVolterra(DynSys):
|
|
206
|
+
"""Lotka-Volterra (predator-prey) dynamical system."""
|
|
207
|
+
|
|
208
|
+
nonnegative = True
|
|
209
|
+
|
|
210
|
+
def __init__(self):
|
|
211
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
212
|
+
|
|
213
|
+
@staticmethod
|
|
214
|
+
def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
|
|
215
|
+
"""LV dynamics
|
|
216
|
+
|
|
217
|
+
Args:
|
|
218
|
+
x: prey population
|
|
219
|
+
y: predator population
|
|
220
|
+
t: time (ignored, since autonomous)
|
|
221
|
+
alpha: prey growth rate
|
|
222
|
+
beta: predation rate
|
|
223
|
+
delta: predator reproduction rate
|
|
224
|
+
gamma: predator death rate
|
|
225
|
+
"""
|
|
226
|
+
dxdt = alpha * x - beta * x * y
|
|
227
|
+
dydt = delta * x * y - gamma * y
|
|
228
|
+
return np.array([dxdt, dydt])
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@_register_dyst
|
|
232
|
+
class Hopf(DynSys):
|
|
233
|
+
"""Hopf normal form dynamical system."""
|
|
234
|
+
|
|
235
|
+
def __init__(self):
|
|
236
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
237
|
+
|
|
238
|
+
@staticmethod
|
|
239
|
+
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
|
|
240
|
+
dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
|
|
241
|
+
dydt = omega * x + mu * y - A * (x**2 * y + y**3)
|
|
242
|
+
return np.array([dxdt, dydt])
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@_register_dyst
|
|
246
|
+
class SHO(DynSys):
|
|
247
|
+
"""Linear damped simple harmonic oscillator"""
|
|
248
|
+
|
|
249
|
+
def __init__(self):
|
|
250
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
251
|
+
|
|
252
|
+
@staticmethod
|
|
253
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
254
|
+
dxdt = a * x + b * y
|
|
255
|
+
dydt = c * x + d * y
|
|
256
|
+
return np.array([dxdt, dydt])
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@_register_dyst
|
|
260
|
+
class CubicHO(DynSys):
|
|
261
|
+
"""Cubic damped harmonic oscillator."""
|
|
262
|
+
|
|
263
|
+
def __init__(self):
|
|
264
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
265
|
+
|
|
266
|
+
@staticmethod
|
|
267
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
268
|
+
dxdt = a * x**3 + b * y**3
|
|
269
|
+
dydt = c * x**3 + d * y**3
|
|
270
|
+
return np.array([dxdt, dydt])
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@_register_dyst
|
|
274
|
+
class VanDerPol(DynSys):
|
|
275
|
+
"""Van der Pol oscillator.
|
|
276
|
+
|
|
277
|
+
dx/dt = y
|
|
278
|
+
dy/dt = mu * (1 - x^2) * y - x
|
|
279
|
+
"""
|
|
280
|
+
|
|
281
|
+
def __init__(self):
|
|
282
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
283
|
+
|
|
284
|
+
@staticmethod
|
|
285
|
+
def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
|
|
286
|
+
dxdt = x_dot
|
|
287
|
+
dx2dt2 = mu * (1 - x**2) * x_dot - x
|
|
288
|
+
return np.array([dxdt, dx2dt2])
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@_register_dyst
|
|
292
|
+
class Kinematics(DynSys):
|
|
293
|
+
"""One-dimensional kinematics with constant acceleration.
|
|
294
|
+
|
|
295
|
+
dx/dt = v
|
|
296
|
+
dv/dt = a
|
|
297
|
+
"""
|
|
298
|
+
|
|
299
|
+
def __init__(self):
|
|
300
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
301
|
+
|
|
302
|
+
@staticmethod
|
|
303
|
+
def _rhs(x, v, t: float, a) -> np.ndarray:
|
|
304
|
+
dxdt = v
|
|
305
|
+
dvdt = a
|
|
306
|
+
return np.array([dxdt, dvdt])
|
sindy_exp/_diffrax_solver.py
CHANGED
|
@@ -85,20 +85,3 @@ def _gen_data_jax(
|
|
|
85
85
|
|
|
86
86
|
def _signal_avg_power(signal: jax.Array) -> jax.Array:
|
|
87
87
|
return jnp.square(signal).mean()
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
## % # noqa:E266
|
|
91
|
-
if __name__ == "__main__":
|
|
92
|
-
# Debug example
|
|
93
|
-
from sindy_exp._data import gen_data
|
|
94
|
-
|
|
95
|
-
data_dict = gen_data(
|
|
96
|
-
"valliselnino",
|
|
97
|
-
seed=50,
|
|
98
|
-
n_trajectories=1,
|
|
99
|
-
ic_stdev=3,
|
|
100
|
-
noise_rel=0.1,
|
|
101
|
-
display=True,
|
|
102
|
-
array_namespace="jax",
|
|
103
|
-
)
|
|
104
|
-
print(data_dict["input_features"])
|
sindy_exp/_odes.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
from importlib import resources
|
|
2
1
|
from logging import getLogger
|
|
3
|
-
from typing import Callable, TypeVar
|
|
2
|
+
from typing import Any, Callable, Literal, TypeVar, cast, overload
|
|
4
3
|
|
|
5
4
|
import matplotlib.pyplot as plt
|
|
6
5
|
import numpy as np
|
|
7
|
-
import pysindy as ps
|
|
8
6
|
import sympy as sp
|
|
9
|
-
from dysts.base import DynSys
|
|
10
7
|
|
|
11
8
|
from ._plotting import (
|
|
12
9
|
compare_coefficient_plots_from_dicts,
|
|
@@ -15,6 +12,7 @@ from ._plotting import (
|
|
|
15
12
|
)
|
|
16
13
|
from ._typing import (
|
|
17
14
|
DynamicsTrialData,
|
|
15
|
+
ExperimentResult,
|
|
18
16
|
FullDynamicsTrialData,
|
|
19
17
|
ProbData,
|
|
20
18
|
SINDyTrialUpdate,
|
|
@@ -41,7 +39,6 @@ metric_ordering = {
|
|
|
41
39
|
T = TypeVar("T", bound=int)
|
|
42
40
|
DType = TypeVar("DType", bound=np.dtype)
|
|
43
41
|
MOD_LOG = getLogger(__name__)
|
|
44
|
-
LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
|
|
45
42
|
|
|
46
43
|
|
|
47
44
|
def _add_forcing(
|
|
@@ -68,112 +65,30 @@ def _add_forcing(
|
|
|
68
65
|
return sum_of_terms
|
|
69
66
|
|
|
70
67
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@staticmethod
|
|
80
|
-
def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
|
|
81
|
-
"""LV dynamics
|
|
82
|
-
|
|
83
|
-
Args:
|
|
84
|
-
x: prey population
|
|
85
|
-
y: predator population
|
|
86
|
-
t: time (ignored, since autonomous)
|
|
87
|
-
alpha: prey growth rate
|
|
88
|
-
beta: predation rate
|
|
89
|
-
delta: predator reproduction rate
|
|
90
|
-
gamma: predator death rate
|
|
91
|
-
"""
|
|
92
|
-
dxdt = alpha * x - beta * x * y
|
|
93
|
-
dydt = delta * x * y - gamma * y
|
|
94
|
-
return np.array([dxdt, dydt])
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
class Hopf(DynSys):
|
|
98
|
-
"""Hopf normal form dynamical system."""
|
|
99
|
-
|
|
100
|
-
def __init__(self):
|
|
101
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
102
|
-
|
|
103
|
-
@staticmethod
|
|
104
|
-
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
|
|
105
|
-
dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
|
|
106
|
-
dydt = omega * x + mu * y - A * (x**2 * y + y**3)
|
|
107
|
-
return np.array([dxdt, dydt])
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
class SHO(DynSys):
|
|
111
|
-
"""Linear damped simple harmonic oscillator"""
|
|
112
|
-
|
|
113
|
-
def __init__(self):
|
|
114
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
115
|
-
|
|
116
|
-
@staticmethod
|
|
117
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
118
|
-
dxdt = a * x + b * y
|
|
119
|
-
dydt = c * x + d * y
|
|
120
|
-
return np.array([dxdt, dydt])
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
class CubicHO(DynSys):
|
|
124
|
-
"""Cubic damped harmonic oscillator."""
|
|
125
|
-
|
|
126
|
-
def __init__(self):
|
|
127
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
128
|
-
|
|
129
|
-
@staticmethod
|
|
130
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
131
|
-
dxdt = a * x**3 + b * y**3
|
|
132
|
-
dydt = c * x**3 + d * y**3
|
|
133
|
-
return np.array([dxdt, dydt])
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class VanDerPol(DynSys):
|
|
137
|
-
"""Van der Pol oscillator.
|
|
138
|
-
|
|
139
|
-
dx/dt = y
|
|
140
|
-
dy/dt = mu * (1 - x^2) * y - x
|
|
141
|
-
"""
|
|
142
|
-
|
|
143
|
-
def __init__(self):
|
|
144
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
145
|
-
|
|
146
|
-
@staticmethod
|
|
147
|
-
def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
|
|
148
|
-
dxdt = x_dot
|
|
149
|
-
dx2dt2 = mu * (1 - x**2) * x_dot - x
|
|
150
|
-
return np.array([dxdt, dx2dt2])
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class Kinematics(DynSys):
|
|
154
|
-
"""One-dimensional kinematics with constant acceleration.
|
|
155
|
-
|
|
156
|
-
dx/dt = v
|
|
157
|
-
dv/dt = a
|
|
158
|
-
"""
|
|
68
|
+
@overload
|
|
69
|
+
def fit_eval(
|
|
70
|
+
data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
|
|
71
|
+
model: _BaseSINDy,
|
|
72
|
+
simulations: Literal[False],
|
|
73
|
+
display: bool,
|
|
74
|
+
) -> ExperimentResult[DynamicsTrialData]: ...
|
|
159
75
|
|
|
160
|
-
def __init__(self):
|
|
161
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
162
76
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
77
|
+
@overload
|
|
78
|
+
def fit_eval(
|
|
79
|
+
data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
|
|
80
|
+
model: _BaseSINDy,
|
|
81
|
+
simulations: Literal[True],
|
|
82
|
+
display: bool,
|
|
83
|
+
) -> ExperimentResult[FullDynamicsTrialData]: ...
|
|
168
84
|
|
|
169
85
|
|
|
170
86
|
def fit_eval(
|
|
171
87
|
data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
|
|
172
|
-
model:
|
|
88
|
+
model: Any,
|
|
173
89
|
simulations: bool = True,
|
|
174
90
|
display: bool = True,
|
|
175
|
-
|
|
176
|
-
) -> dict | tuple[dict, DynamicsTrialData | FullDynamicsTrialData]:
|
|
91
|
+
) -> ExperimentResult:
|
|
177
92
|
"""Fit and evaluate a SINDy model on a set of trajectories.
|
|
178
93
|
|
|
179
94
|
Args:
|
|
@@ -184,11 +99,8 @@ def fit_eval(
|
|
|
184
99
|
model: A SINDy-like model implementing the _BaseSINDy protocol.
|
|
185
100
|
simulations: Whether to run forward simulations for evaluation.
|
|
186
101
|
display: Whether to generate plots as part of evaluation.
|
|
187
|
-
return_all: If True, return a dictionary containing metrics and the
|
|
188
|
-
assembled DynamicsTrialData; otherwise return only the metrics
|
|
189
|
-
dictionary.
|
|
190
102
|
"""
|
|
191
|
-
|
|
103
|
+
model = cast(_BaseSINDy, model)
|
|
192
104
|
trajectories, true_equations = data
|
|
193
105
|
input_features = trajectories[0].input_features
|
|
194
106
|
|
|
@@ -198,12 +110,16 @@ def fit_eval(
|
|
|
198
110
|
|
|
199
111
|
MOD_LOG.info(f"Fitting a model: {model}")
|
|
200
112
|
coeff_true_dicts, coeff_est_dicts = unionize_coeff_dicts(model, true_equations)
|
|
201
|
-
|
|
113
|
+
|
|
114
|
+
# Special workaround for pysindy's legacy WeakPDELibrary
|
|
115
|
+
if hasattr(model.feature_library, "K"):
|
|
202
116
|
# WeakPDE library fails to simulate, so insert nonweak library
|
|
203
117
|
# to Pipeline and SINDy model.
|
|
204
118
|
inner_lib = model.feature_library.function_library
|
|
205
119
|
model.feature_library = inner_lib # type: ignore # TODO: Fix in pysindy
|
|
206
|
-
|
|
120
|
+
|
|
121
|
+
# Special workaround for pysindy's bad (soon to be legacy) differentiation API
|
|
122
|
+
if hasattr(model, "differentiation_method") and hasattr(
|
|
207
123
|
model.differentiation_method, "smoothed_x_"
|
|
208
124
|
):
|
|
209
125
|
smooth_x = []
|
|
@@ -212,6 +128,7 @@ def fit_eval(
|
|
|
212
128
|
smooth_x.append(model.differentiation_method.smoothed_x_)
|
|
213
129
|
else: # using WeakPDELibrary
|
|
214
130
|
smooth_x = x_train
|
|
131
|
+
|
|
215
132
|
trial_data = DynamicsTrialData(
|
|
216
133
|
trajectories=trajectories,
|
|
217
134
|
true_equations=coeff_true_dicts,
|
|
@@ -264,9 +181,8 @@ def fit_eval(
|
|
|
264
181
|
figs=(fig_composite, fig_by_coord_1d),
|
|
265
182
|
coord_names=input_features,
|
|
266
183
|
)
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
return metrics
|
|
184
|
+
|
|
185
|
+
return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
|
|
270
186
|
|
|
271
187
|
|
|
272
188
|
def plot_ode_panel(trial_data: DynamicsTrialData):
|
sindy_exp/_typing.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import (
|
|
|
7
7
|
NamedTuple,
|
|
8
8
|
Optional,
|
|
9
9
|
Protocol,
|
|
10
|
+
TypedDict,
|
|
10
11
|
TypeVar,
|
|
11
12
|
overload,
|
|
12
13
|
)
|
|
@@ -27,6 +28,14 @@ FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]
|
|
|
27
28
|
TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class ExperimentResult[T](TypedDict):
|
|
32
|
+
"""Results from a SINDy ODE experiment."""
|
|
33
|
+
|
|
34
|
+
metrics: float
|
|
35
|
+
data: T
|
|
36
|
+
main: float
|
|
37
|
+
|
|
38
|
+
|
|
30
39
|
class _BaseSINDy(Protocol):
|
|
31
40
|
optimizer: Any
|
|
32
41
|
feature_library: Any
|
sindy_exp/_utils.py
CHANGED
|
@@ -182,14 +182,18 @@ def coeff_metrics(
|
|
|
182
182
|
coefficients[row_ind, col_ind] = est_row[feat]
|
|
183
183
|
|
|
184
184
|
metrics: dict[str, float | np.floating] = {}
|
|
185
|
-
metrics["coeff_precision"] =
|
|
186
|
-
|
|
185
|
+
metrics["coeff_precision"] = float(
|
|
186
|
+
sklearn.metrics.precision_score(
|
|
187
|
+
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
188
|
+
)
|
|
187
189
|
)
|
|
188
|
-
metrics["coeff_recall"] =
|
|
189
|
-
|
|
190
|
+
metrics["coeff_recall"] = float(
|
|
191
|
+
sklearn.metrics.recall_score(
|
|
192
|
+
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
193
|
+
)
|
|
190
194
|
)
|
|
191
|
-
metrics["coeff_f1"] =
|
|
192
|
-
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
195
|
+
metrics["coeff_f1"] = float(
|
|
196
|
+
sklearn.metrics.f1_score(coeff_true.flatten() != 0, coefficients.flatten() != 0)
|
|
193
197
|
)
|
|
194
198
|
metrics["coeff_mse"] = sklearn.metrics.mean_squared_error(
|
|
195
199
|
coeff_true.flatten(), coefficients.flatten()
|
|
@@ -279,51 +283,6 @@ def unionize_coeff_dicts(
|
|
|
279
283
|
return true_aligned, est_aligned
|
|
280
284
|
|
|
281
285
|
|
|
282
|
-
def make_model(
|
|
283
|
-
input_features: list[str],
|
|
284
|
-
dt: float,
|
|
285
|
-
diff_params: dict | ps.BaseDifferentiation,
|
|
286
|
-
feat_params: dict | ps.feature_library.base.BaseFeatureLibrary,
|
|
287
|
-
opt_params: dict | ps.BaseOptimizer,
|
|
288
|
-
) -> ps.SINDy:
|
|
289
|
-
"""Build a model with object parameters dictionaries
|
|
290
|
-
|
|
291
|
-
e.g. {"kind": "finitedifference"} instead of FiniteDifference()
|
|
292
|
-
"""
|
|
293
|
-
|
|
294
|
-
def finalize_param(lookup_func, pdict, lookup_key):
|
|
295
|
-
try:
|
|
296
|
-
cls_name = pdict.pop(lookup_key)
|
|
297
|
-
except AttributeError:
|
|
298
|
-
cls_name = pdict.vals.pop(lookup_key)
|
|
299
|
-
pdict = pdict.vals
|
|
300
|
-
|
|
301
|
-
param_cls = lookup_func(cls_name)
|
|
302
|
-
param_final = param_cls(**pdict)
|
|
303
|
-
pdict[lookup_key] = cls_name
|
|
304
|
-
return param_final
|
|
305
|
-
|
|
306
|
-
if isinstance(diff_params, ps.BaseDifferentiation):
|
|
307
|
-
diff = diff_params
|
|
308
|
-
else:
|
|
309
|
-
diff = finalize_param(diff_lookup, diff_params, "diffcls")
|
|
310
|
-
if isinstance(feat_params, ps.feature_library.base.BaseFeatureLibrary):
|
|
311
|
-
features = feat_params
|
|
312
|
-
else:
|
|
313
|
-
features = finalize_param(feature_lookup, feat_params, "featcls")
|
|
314
|
-
if isinstance(opt_params, ps.BaseOptimizer):
|
|
315
|
-
opt = opt_params
|
|
316
|
-
else:
|
|
317
|
-
opt = finalize_param(opt_lookup, opt_params, "optcls")
|
|
318
|
-
return ps.SINDy(
|
|
319
|
-
differentiation_method=diff,
|
|
320
|
-
optimizer=opt,
|
|
321
|
-
t_default=dt, # type: ignore
|
|
322
|
-
feature_library=features,
|
|
323
|
-
feature_names=input_features,
|
|
324
|
-
)
|
|
325
|
-
|
|
326
|
-
|
|
327
286
|
def _simulate_test_data(
|
|
328
287
|
model: _BaseSINDy, dt: float, x_test: Float2D
|
|
329
288
|
) -> SINDyTrialUpdate:
|
sindy_exp/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
sindy_exp/__init__.py,sha256=OH4tzHuhWmgXexz-QbXMZtuzkv9kZYyIFBUmDa9n73Q,741
|
|
2
|
+
sindy_exp/_data.py,sha256=_PMoXN4JHhR0bZi2ivD9_EB7rUCJC2xhDuRd4J8edgM,9233
|
|
3
|
+
sindy_exp/_diffrax_solver.py,sha256=c-IjDqaAwLj0rZ4vIm8pMm1U_9K6YiE1XCaL72NAVpI,2362
|
|
4
|
+
sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
|
|
5
|
+
sindy_exp/_odes.py,sha256=cMUNvS6TL4_nRLrjK2MBoLzgpM__gu4NwzY2n4gfQn8,6513
|
|
6
|
+
sindy_exp/_plotting.py,sha256=dpcqAXKzb0mSVl0p2WyMadVoGhQi43oL6ZqsbuheEuk,17470
|
|
7
|
+
sindy_exp/_typing.py,sha256=_KKtGcXOZmlR3Fg77G6TmlH5eKDwnoHjusex2Gxlf_4,4196
|
|
8
|
+
sindy_exp/_utils.py,sha256=zR9Npjl8PeeSp6MHHD5lL3q37gDXIOLTEfDqGdY_2fM,11341
|
|
9
|
+
sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
|
|
10
|
+
sindy_exp/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
sindy_exp-0.2.2.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
|
|
12
|
+
sindy_exp-0.2.2.dist-info/METADATA,sha256=JU-BigVCn7w2WhWDRfA3d32RyEz7Yb4EKjxc1ng-a4o,4514
|
|
13
|
+
sindy_exp-0.2.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
14
|
+
sindy_exp-0.2.2.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
|
|
15
|
+
sindy_exp-0.2.2.dist-info/RECORD,,
|
sindy_exp-0.2.0.dist-info/RECORD
DELETED
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
sindy_exp/__init__.py,sha256=F1bJz9Gzk2nPB6DGjEa0qZcaQPTb-Yhh0ZnO9TcQci0,689
|
|
2
|
-
sindy_exp/_data.py,sha256=xEQByYeL2ejPdCpFVP9oH48Faoiz1E3OjRhK06lmkFo,6512
|
|
3
|
-
sindy_exp/_diffrax_solver.py,sha256=YmKF-U1fkxW4z4AbXW_qm1sX6eEy1BGzKbqL63dxo2M,2705
|
|
4
|
-
sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
|
|
5
|
-
sindy_exp/_odes.py,sha256=RN97nh1Y_6DQCLPfRTyCCJq3exxBq9pfVHVPTxr0na4,8830
|
|
6
|
-
sindy_exp/_plotting.py,sha256=dpcqAXKzb0mSVl0p2WyMadVoGhQi43oL6ZqsbuheEuk,17470
|
|
7
|
-
sindy_exp/_typing.py,sha256=hIOj270YkoITi_qJ8xTO1EaBeoVzcDidIGMXuC0QHzM,4046
|
|
8
|
-
sindy_exp/_utils.py,sha256=RnMoNV_N7ubWVMbgdiJaRGo437DvWUcOIn4fP2rhwJI,12717
|
|
9
|
-
sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
|
|
10
|
-
sindy_exp-0.2.0.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
|
|
11
|
-
sindy_exp-0.2.0.dist-info/METADATA,sha256=_BRUa1zXAqRJtMWeIcyUDyx74eo_VeAwu7V1NVhbhZM,4514
|
|
12
|
-
sindy_exp-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
13
|
-
sindy_exp-0.2.0.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
|
|
14
|
-
sindy_exp-0.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|