sindy-exp 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sindy_exp/__init__.py CHANGED
@@ -7,14 +7,16 @@ from ._plotting import (
7
7
  plot_test_trajectory,
8
8
  plot_training_data,
9
9
  )
10
- from ._typing import DynamicsTrialData, ProbData
10
+ from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData, SimProbData
11
11
  from ._utils import coeff_metrics, integration_metrics, pred_metrics
12
12
 
13
13
  __all__ = [
14
14
  "gen_data",
15
15
  "fit_eval",
16
16
  "ProbData",
17
+ "SimProbData",
17
18
  "DynamicsTrialData",
19
+ "FullDynamicsTrialData",
18
20
  "coeff_metrics",
19
21
  "pred_metrics",
20
22
  "integration_metrics",
sindy_exp/_data.py CHANGED
@@ -1,15 +1,17 @@
1
+ from importlib import resources
1
2
  from logging import getLogger
2
- from typing import Any, Callable, Optional, cast
3
+ from typing import Callable, Optional, cast
3
4
 
4
5
  import dysts.flows
5
6
  import dysts.systems
6
7
  import numpy as np
7
8
  import scipy
9
+ import sympy as sp
10
+ from dysts.base import DynSys
8
11
 
9
12
  from ._dysts_to_sympy import dynsys_to_sympy
10
- from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
11
13
  from ._plotting import plot_training_data
12
- from ._typing import Float1D, ProbData
14
+ from ._typing import ExperimentResult, Float1D, SimProbData
13
15
  from ._utils import _sympy_expr_to_feat_coeff
14
16
 
15
17
  try:
@@ -21,21 +23,12 @@ except ImportError:
21
23
 
22
24
  INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
23
25
  MOD_LOG = getLogger(__name__)
26
+ LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
24
27
 
25
28
  ODE_CLASSES = {
26
29
  klass.lower(): getattr(dysts.flows, klass)
27
30
  for klass in dysts.systems.get_attractor_list()
28
31
  }
29
- ODE_CLASSES.update(
30
- {
31
- "lotkavolterra": LotkaVolterra,
32
- "sho": SHO,
33
- "cubicho": CubicHO,
34
- "hopf": Hopf,
35
- "vanderpol": VanDerPol,
36
- "kinematics": Kinematics,
37
- }
38
- )
39
32
 
40
33
 
41
34
  def gen_data(
@@ -49,7 +42,7 @@ def gen_data(
49
42
  t_end: float = 10,
50
43
  display: bool = False,
51
44
  array_namespace: str = "numpy",
52
- ) -> dict[str, Any]:
45
+ ) -> ExperimentResult[tuple[list[SimProbData], list[dict[sp.Expr, float]]]]:
53
46
  """Generate random training and test data
54
47
 
55
48
  An Experiment step according to the mitosis experiment runner.
@@ -82,7 +75,7 @@ def gen_data(
82
75
  coeff_true = _sympy_expr_to_feat_coeff(sp_expr)
83
76
  rhsfunc = lambda t, X: dyst_sys.rhs(X, t) # noqa: E731
84
77
  try:
85
- x0_center = dyst_sys.ic
78
+ x0_center = cast(Float1D, dyst_sys.ic)
86
79
  except KeyError:
87
80
  x0_center = np.zeros((len(input_features)), dtype=np.float64)
88
81
  try:
@@ -95,7 +88,7 @@ def gen_data(
95
88
  noise_abs = 0.1
96
89
 
97
90
  MOD_LOG.info(f"Generating {n_trajectories} trajectories of f{system}")
98
- prob_data_list: list[ProbData] = []
91
+ prob_data_list: list[SimProbData] = []
99
92
  if array_namespace == "numpy":
100
93
  feature_names = [feat.name for feat in input_features]
101
94
  for _ in range(n_trajectories):
@@ -115,20 +108,20 @@ def gen_data(
115
108
  prob_data_list.append(prob)
116
109
  elif array_namespace == "jax":
117
110
  try:
118
- globals()["_gen_data_jax"]
119
- except KeyError:
111
+ jax # type: ignore
112
+ except NameError:
120
113
  raise ImportError(
121
114
  "jax data generation requested but diffrax or sympy2jax not"
122
115
  " installed"
123
116
  )
124
- this_seed = jax.random.PRNGKey(seed)
117
+ this_seed = jax.random.PRNGKey(seed) # type: ignore
125
118
  for _ in range(n_trajectories):
126
- this_seed, _ = jax.random.split(this_seed)
127
- prob = _gen_data_jax(
119
+ this_seed, _ = jax.random.split(this_seed) # type: ignore
120
+ prob = _gen_data_jax( # type: ignore
128
121
  sp_expr,
129
122
  input_features,
130
123
  this_seed,
131
- x0_center=x0_center,
124
+ x0_center=x0_center, # type: ignore # numpy->jax
132
125
  nonnegative=nonnegative,
133
126
  ic_stdev=ic_stdev,
134
127
  noise_abs=noise_abs,
@@ -143,10 +136,12 @@ def gen_data(
143
136
  )
144
137
  if display and prob_data_list:
145
138
  sample = prob_data_list[0]
139
+ assert sample.x_train_true is not None # typing
146
140
  figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
147
141
  figs[0].suptitle("Sample Trajectory")
142
+
148
143
  return {
149
- "data": {"trajectories": prob_data_list, "coeff_true": coeff_true},
144
+ "data": (prob_data_list, coeff_true),
150
145
  "main": f"{n_trajectories} trajectories of {rhsfunc}",
151
146
  "metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
152
147
  }
@@ -163,7 +158,7 @@ def _gen_data(
163
158
  nonnegative: bool,
164
159
  dt: float,
165
160
  t_end: float,
166
- ) -> ProbData:
161
+ ) -> SimProbData:
167
162
  rng = np.random.default_rng(seed)
168
163
  t_train = np.arange(0, t_end, dt)
169
164
  t_train_span = (t_train[0], t_train[-1])
@@ -187,8 +182,9 @@ def _gen_data(
187
182
  noise_abs = np.sqrt(_signal_avg_power(x_train) * noise_rel)
188
183
  x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape)
189
184
 
190
- return ProbData(
191
- dt, t_train, x_train, x_train_true, x_train_true_dot, input_features
185
+ assert noise_abs is not None # typing
186
+ return SimProbData(
187
+ t_train, x_train, input_features, x_train_true, x_train_true_dot, noise_abs
192
188
  )
193
189
 
194
190
 
@@ -200,3 +196,116 @@ def _max_amplitude(signal: np.ndarray, axis: int) -> float:
200
196
 
201
197
  def _signal_avg_power(signal: np.ndarray) -> float:
202
198
  return np.square(signal).mean()
199
+
200
+
201
+ def _register_dyst(klass: type[DynSys]) -> type[DynSys]:
202
+ """Register a custom dysts DynSys class for use in sindy_exp data generation."""
203
+ ODE_CLASSES[klass.__name__.lower()] = klass
204
+ return klass
205
+
206
+
207
+ @_register_dyst
208
+ class LotkaVolterra(DynSys):
209
+ """Lotka-Volterra (predator-prey) dynamical system."""
210
+
211
+ nonnegative = True
212
+
213
+ def __init__(self):
214
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
215
+
216
+ @staticmethod
217
+ def _rhs( # type: ignore # dysts
218
+ x, y, t: float, alpha, beta, gamma, delta
219
+ ) -> np.ndarray:
220
+ """LV dynamics
221
+
222
+ Args:
223
+ x: prey population
224
+ y: predator population
225
+ t: time (ignored, since autonomous)
226
+ alpha: prey growth rate
227
+ beta: predation rate
228
+ delta: predator reproduction rate
229
+ gamma: predator death rate
230
+ """
231
+ dxdt = alpha * x - beta * x * y
232
+ dydt = delta * x * y - gamma * y
233
+ return np.array([dxdt, dydt])
234
+
235
+
236
+ @_register_dyst
237
+ class Hopf(DynSys):
238
+ """Hopf normal form dynamical system."""
239
+
240
+ def __init__(self):
241
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
242
+
243
+ @staticmethod
244
+ def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray: # type: ignore # dysts
245
+ dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
246
+ dydt = omega * x + mu * y - A * (x**2 * y + y**3)
247
+ return np.array([dxdt, dydt])
248
+
249
+
250
+ @_register_dyst
251
+ class SHO(DynSys):
252
+ """Linear damped simple harmonic oscillator"""
253
+
254
+ def __init__(self):
255
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
256
+
257
+ @staticmethod
258
+ def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
259
+ dxdt = a * x + b * y
260
+ dydt = c * x + d * y
261
+ return np.array([dxdt, dydt])
262
+
263
+
264
+ @_register_dyst
265
+ class CubicHO(DynSys):
266
+ """Cubic damped harmonic oscillator."""
267
+
268
+ def __init__(self):
269
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
270
+
271
+ @staticmethod
272
+ def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
273
+ dxdt = a * x**3 + b * y**3
274
+ dydt = c * x**3 + d * y**3
275
+ return np.array([dxdt, dydt])
276
+
277
+
278
+ @_register_dyst
279
+ class VanDerPol(DynSys):
280
+ """Van der Pol oscillator.
281
+
282
+ dx/dt = y
283
+ dy/dt = mu * (1 - x^2) * y - x
284
+ """
285
+
286
+ def __init__(self):
287
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
288
+
289
+ @staticmethod
290
+ def _rhs(x, x_dot, t: float, mu) -> np.ndarray: # type: ignore # dysts
291
+ dxdt = x_dot
292
+ dx2dt2 = mu * (1 - x**2) * x_dot - x
293
+ return np.array([dxdt, dx2dt2])
294
+
295
+
296
+ @_register_dyst
297
+ class Kinematics(DynSys):
298
+ """One-dimensional kinematics with constant acceleration.
299
+
300
+ dx/dt = v
301
+ dv/dt = a
302
+ """
303
+
304
+ def __init__(self):
305
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
306
+
307
+ @staticmethod
308
+ def _rhs(x, v, t: float, a) -> np.ndarray: # type: ignore # dysts
309
+ dxdt = v
310
+ dvdt = a
311
+ return np.array([dxdt, dvdt])
@@ -6,7 +6,7 @@ import jax.numpy as jnp
6
6
  import sympy2jax
7
7
  from sympy import Expr, Symbol
8
8
 
9
- from ._typing import ProbData
9
+ from ._typing import SimProbData
10
10
 
11
11
  jax.config.update("jax_enable_x64", True)
12
12
 
@@ -22,7 +22,7 @@ def _gen_data_jax(
22
22
  nonnegative: bool,
23
23
  dt: float,
24
24
  t_end: float,
25
- ) -> ProbData:
25
+ ) -> SimProbData:
26
26
  rhstree = sympy2jax.SymbolicModule(exprs)
27
27
 
28
28
  def ode_sys(t, state, args):
@@ -71,6 +71,8 @@ def _gen_data_jax(
71
71
  if noise_abs is None:
72
72
  assert noise_rel is not None # force type narrowing
73
73
  noise_abs = float(jnp.sqrt(_signal_avg_power(x_train_true)) * noise_rel)
74
+ else:
75
+ noise_rel = noise_abs / float(jnp.sqrt(_signal_avg_power(x_train_true)))
74
76
 
75
77
  x_train = x_train_true + jax.random.normal(key, x_train_true.shape) * noise_abs
76
78
 
@@ -78,27 +80,16 @@ def _gen_data_jax(
78
80
  x_train_true_dot = jnp.array([ode_sys(0, xi, None) for xi in x_train_true])
79
81
 
80
82
  stringy_features = [sym.name for sym in input_features]
81
- return ProbData(
82
- dt, t_train, x_train, x_train_true, x_train_true_dot, stringy_features, sol
83
+ return SimProbData(
84
+ t_train, # type: ignore # jax->numpy
85
+ x_train, # type: ignore # jax->numpy
86
+ stringy_features,
87
+ x_train_true, # type: ignore # jax->numpy
88
+ x_train_true_dot, # type: ignore # jax->numpy
89
+ noise_abs,
90
+ sol,
83
91
  )
84
92
 
85
93
 
86
94
  def _signal_avg_power(signal: jax.Array) -> jax.Array:
87
95
  return jnp.square(signal).mean()
88
-
89
-
90
- ## % # noqa:E266
91
- if __name__ == "__main__":
92
- # Debug example
93
- from sindy_exp._data import gen_data
94
-
95
- data_dict = gen_data(
96
- "valliselnino",
97
- seed=50,
98
- n_trajectories=1,
99
- ic_stdev=3,
100
- noise_rel=0.1,
101
- display=True,
102
- array_namespace="jax",
103
- )
104
- print(data_dict["input_features"])
sindy_exp/_odes.py CHANGED
@@ -1,12 +1,9 @@
1
- from importlib import resources
2
1
  from logging import getLogger
3
- from typing import Callable, TypeVar
2
+ from typing import Any, Literal, TypeVar, cast, overload
4
3
 
5
4
  import matplotlib.pyplot as plt
6
5
  import numpy as np
7
- import pysindy as ps
8
6
  import sympy as sp
9
- from dysts.base import DynSys
10
7
 
11
8
  from ._plotting import (
12
9
  compare_coefficient_plots_from_dicts,
@@ -15,8 +12,9 @@ from ._plotting import (
15
12
  )
16
13
  from ._typing import (
17
14
  DynamicsTrialData,
15
+ ExperimentResult,
18
16
  FullDynamicsTrialData,
19
- ProbData,
17
+ SimProbData,
20
18
  SINDyTrialUpdate,
21
19
  _BaseSINDy,
22
20
  )
@@ -41,154 +39,46 @@ metric_ordering = {
41
39
  T = TypeVar("T", bound=int)
42
40
  DType = TypeVar("DType", bound=np.dtype)
43
41
  MOD_LOG = getLogger(__name__)
44
- LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
45
42
 
46
43
 
47
- def _add_forcing(
48
- forcing_func: Callable[[float], np.ndarray[tuple[T], DType]],
49
- auto_func: Callable[
50
- [float, np.ndarray[tuple[T], DType]], np.ndarray[tuple[T], DType]
51
- ],
52
- ) -> Callable[[float, np.ndarray], np.ndarray]:
53
- """Add a time-dependent forcing term to a rhs func
54
-
55
- Args:
56
- forcing_func: The forcing function to add
57
- auto_func: An existing rhs func for solve_ivp
58
-
59
- Returns:
60
- A rhs function for integration
61
- """
62
-
63
- def sum_of_terms(
64
- t: float, state: np.ndarray[tuple[T], DType]
65
- ) -> np.ndarray[tuple[T], DType]:
66
- return np.array(forcing_func(t)) + np.array(auto_func(t, state))
67
-
68
- return sum_of_terms
69
-
70
-
71
- class LotkaVolterra(DynSys):
72
- """Lotka-Volterra (predator-prey) dynamical system."""
73
-
74
- nonnegative = True
75
-
76
- def __init__(self):
77
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
78
-
79
- @staticmethod
80
- def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
81
- """LV dynamics
82
-
83
- Args:
84
- x: prey population
85
- y: predator population
86
- t: time (ignored, since autonomous)
87
- alpha: prey growth rate
88
- beta: predation rate
89
- delta: predator reproduction rate
90
- gamma: predator death rate
91
- """
92
- dxdt = alpha * x - beta * x * y
93
- dydt = delta * x * y - gamma * y
94
- return np.array([dxdt, dydt])
95
-
96
-
97
- class Hopf(DynSys):
98
- """Hopf normal form dynamical system."""
99
-
100
- def __init__(self):
101
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
102
-
103
- @staticmethod
104
- def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
105
- dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
106
- dydt = omega * x + mu * y - A * (x**2 * y + y**3)
107
- return np.array([dxdt, dydt])
108
-
109
-
110
- class SHO(DynSys):
111
- """Linear damped simple harmonic oscillator"""
112
-
113
- def __init__(self):
114
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
115
-
116
- @staticmethod
117
- def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
118
- dxdt = a * x + b * y
119
- dydt = c * x + d * y
120
- return np.array([dxdt, dydt])
121
-
122
-
123
- class CubicHO(DynSys):
124
- """Cubic damped harmonic oscillator."""
125
-
126
- def __init__(self):
127
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
128
-
129
- @staticmethod
130
- def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
131
- dxdt = a * x**3 + b * y**3
132
- dydt = c * x**3 + d * y**3
133
- return np.array([dxdt, dydt])
134
-
135
-
136
- class VanDerPol(DynSys):
137
- """Van der Pol oscillator.
138
-
139
- dx/dt = y
140
- dy/dt = mu * (1 - x^2) * y - x
141
- """
142
-
143
- def __init__(self):
144
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
145
-
146
- @staticmethod
147
- def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
148
- dxdt = x_dot
149
- dx2dt2 = mu * (1 - x**2) * x_dot - x
150
- return np.array([dxdt, dx2dt2])
151
-
152
-
153
- class Kinematics(DynSys):
154
- """One-dimensional kinematics with constant acceleration.
155
-
156
- dx/dt = v
157
- dv/dt = a
158
- """
44
+ @overload
45
+ def fit_eval(
46
+ data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
47
+ model: _BaseSINDy,
48
+ simulations: Literal[False],
49
+ display: bool,
50
+ ) -> ExperimentResult[DynamicsTrialData]: ...
159
51
 
160
- def __init__(self):
161
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
162
52
 
163
- @staticmethod
164
- def _rhs(x, v, t: float, a) -> np.ndarray:
165
- dxdt = v
166
- dvdt = a
167
- return np.array([dxdt, dvdt])
53
+ @overload
54
+ def fit_eval(
55
+ data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
56
+ model: _BaseSINDy,
57
+ simulations: Literal[True],
58
+ display: bool,
59
+ ) -> ExperimentResult[FullDynamicsTrialData]: ...
168
60
 
169
61
 
170
62
  def fit_eval(
171
- data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
172
- model: _BaseSINDy,
63
+ data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
64
+ model: Any,
173
65
  simulations: bool = True,
174
66
  display: bool = True,
175
- return_all: bool = True,
176
- ) -> dict | tuple[dict, DynamicsTrialData | FullDynamicsTrialData]:
67
+ ) -> ExperimentResult:
177
68
  """Fit and evaluate a SINDy model on a set of trajectories.
178
69
 
179
70
  Args:
180
71
  data: Tuple of (trajectories, true_equations), where ``trajectories`` is
181
- a list of ProbData objects and ``true_equations`` is a list of
72
+ a list of SimProbData objects and ``true_equations`` is a list of
182
73
  dictionaries mapping SymPy symbols to their true coefficients for
183
74
  each state coordinate.
184
75
  model: A SINDy-like model implementing the _BaseSINDy protocol.
185
76
  simulations: Whether to run forward simulations for evaluation.
186
77
  display: Whether to generate plots as part of evaluation.
187
- return_all: If True, return a dictionary containing metrics and the
188
- assembled DynamicsTrialData; otherwise return only the metrics
189
- dictionary.
190
78
  """
191
-
79
+ model = cast(_BaseSINDy, model)
80
+ for trajectory in data[0]:
81
+ assert trajectory.x_train_true is not None
192
82
  trajectories, true_equations = data
193
83
  input_features = trajectories[0].input_features
194
84
 
@@ -198,12 +88,16 @@ def fit_eval(
198
88
 
199
89
  MOD_LOG.info(f"Fitting a model: {model}")
200
90
  coeff_true_dicts, coeff_est_dicts = unionize_coeff_dicts(model, true_equations)
201
- if isinstance(model.feature_library, ps.WeakPDELibrary):
91
+
92
+ # Special workaround for pysindy's legacy WeakPDELibrary
93
+ if hasattr(model.feature_library, "K"):
202
94
  # WeakPDE library fails to simulate, so insert nonweak library
203
95
  # to Pipeline and SINDy model.
204
96
  inner_lib = model.feature_library.function_library
205
97
  model.feature_library = inner_lib # type: ignore # TODO: Fix in pysindy
206
- if isinstance(model, ps.SINDy) and hasattr(
98
+
99
+ # Special workaround for pysindy's bad (soon to be legacy) differentiation API
100
+ if hasattr(model, "differentiation_method") and hasattr(
207
101
  model.differentiation_method, "smoothed_x_"
208
102
  ):
209
103
  smooth_x = []
@@ -212,6 +106,7 @@ def fit_eval(
212
106
  smooth_x.append(model.differentiation_method.smoothed_x_)
213
107
  else: # using WeakPDELibrary
214
108
  smooth_x = x_train
109
+
215
110
  trial_data = DynamicsTrialData(
216
111
  trajectories=trajectories,
217
112
  true_equations=coeff_true_dicts,
@@ -226,7 +121,7 @@ def fit_eval(
226
121
  sims: list[SINDyTrialUpdate] = []
227
122
  integration_metric_list: list[dict[str, float | np.floating]] = []
228
123
  for traj in trajectories:
229
- sim = _simulate_test_data(model, traj.dt, traj.x_train_true)
124
+ sim = _simulate_test_data(model, traj.t_train, traj.x_train_true)
230
125
  sims.append(sim)
231
126
  integration_metric_list.append(
232
127
  integration_metrics(
@@ -237,9 +132,9 @@ def fit_eval(
237
132
  )
238
133
  )
239
134
 
240
- agg_integration_metrics: dict[str, float | np.floating] = {}
135
+ agg_integration_metrics: dict[str, float] = {}
241
136
  for key in integration_metric_list[0].keys():
242
- values = [m[key] for m in integration_metric_list]
137
+ values = cast(list[float], [m[key] for m in integration_metric_list])
243
138
  agg_integration_metrics[key] = float(np.mean(values))
244
139
  metrics.update(agg_integration_metrics)
245
140
 
@@ -264,9 +159,8 @@ def fit_eval(
264
159
  figs=(fig_composite, fig_by_coord_1d),
265
160
  coord_names=input_features,
266
161
  )
267
- if return_all:
268
- return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
269
- return metrics
162
+
163
+ return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
270
164
 
271
165
 
272
166
  def plot_ode_panel(trial_data: DynamicsTrialData):
sindy_exp/_plotting.py CHANGED
@@ -57,7 +57,7 @@ def plot_coefficients(
57
57
  feature_names: Sequence[str],
58
58
  ax: Axes,
59
59
  **heatmap_kws,
60
- ) -> None:
60
+ ) -> Axes:
61
61
  """Plot a set of dynamical system coefficients in a heatmap.
62
62
 
63
63
  Args:
@@ -162,6 +162,7 @@ def _compare_coefficient_plots_impl(
162
162
  1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True
163
163
  )
164
164
  fig.tight_layout()
165
+ assert axs is not None # type narrowing
165
166
 
166
167
  vmax = signed_root(max_val)
167
168
 
@@ -275,7 +276,12 @@ def _plot_training_trajectory(
275
276
  """
276
277
  if x_train.shape[1] == 2:
277
278
  ax.plot(
278
- x_true[:, 0], x_true[:, 1], ".", label="True", color=COLOR.TRUE, **PLOT_KWS
279
+ x_true[:, 0],
280
+ x_true[:, 1],
281
+ ".",
282
+ label="True",
283
+ color=COLOR.TRUE,
284
+ **PLOT_KWS, # type: ignore[arg-type]
279
285
  )
280
286
  ax.plot(
281
287
  x_train[:, 0],
@@ -283,7 +289,7 @@ def _plot_training_trajectory(
283
289
  ".",
284
290
  label="Measured",
285
291
  color=COLOR.MEAS,
286
- **PLOT_KWS,
292
+ **PLOT_KWS, # type: ignore[arg-type]
287
293
  )
288
294
  if (
289
295
  x_smooth is not None
@@ -295,7 +301,7 @@ def _plot_training_trajectory(
295
301
  ".",
296
302
  label="Smoothed",
297
303
  color=COLOR.EST,
298
- **PLOT_KWS,
304
+ **PLOT_KWS, # type: ignore[arg-type]
299
305
  )
300
306
  if labels:
301
307
  ax.set(xlabel="$x_0$", ylabel="$x_1$")
@@ -308,7 +314,7 @@ def _plot_training_trajectory(
308
314
  x_true[:, 2],
309
315
  color=COLOR.TRUE,
310
316
  label="True values",
311
- **PLOT_KWS,
317
+ **PLOT_KWS, # type: ignore[arg-type]
312
318
  )
313
319
 
314
320
  ax.plot(
sindy_exp/_typing.py CHANGED
@@ -1,12 +1,13 @@
1
1
  from collections import defaultdict
2
+ from collections.abc import Mapping
2
3
  from dataclasses import dataclass
3
4
  from typing import (
4
5
  Any,
5
6
  Callable,
6
7
  Literal,
7
- NamedTuple,
8
8
  Optional,
9
9
  Protocol,
10
+ TypedDict,
10
11
  TypeVar,
11
12
  overload,
12
13
  )
@@ -27,6 +28,14 @@ FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]
27
28
  TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
28
29
 
29
30
 
31
+ class ExperimentResult[T](TypedDict):
32
+ """Results from a SINDy ODE experiment."""
33
+
34
+ metrics: Mapping[str, float | None]
35
+ data: T
36
+ main: object
37
+
38
+
30
39
  class _BaseSINDy(Protocol):
31
40
  optimizer: Any
32
41
  feature_library: Any
@@ -62,23 +71,38 @@ class _BaseSINDy(Protocol):
62
71
  self, precision: int, fmt: Literal["sympy"]
63
72
  ) -> list[dict[Expr, float]]: ...
64
73
 
74
+ @overload
75
+ def print(self, **kwargs) -> None: ...
76
+
77
+ @overload
65
78
  def print(self, precision: int, **kwargs) -> None: ...
66
79
 
67
80
  def get_feature_names(self) -> list[str]: ...
68
81
 
69
82
 
70
- class ProbData(NamedTuple):
71
- """Data bundle for a single trajectory.
83
+ @dataclass
84
+ class ProbData:
85
+ """Represents a single trajectory's data.
72
86
 
73
- Represents a trajectory's training data and associated metadata.
87
+ For measured data, only t_train, x_train, and input_features are required.
74
88
  """
75
89
 
76
- dt: float
77
90
  t_train: Float1D
78
91
  x_train: Float2D
92
+ input_features: list[str]
93
+
94
+
95
+ @dataclass
96
+ class SimProbData(ProbData):
97
+ """For simulated data, the noiseless trajectory is known.
98
+
99
+ Optionally includes the integrator solution object for evaluating
100
+ at other points.
101
+ """
102
+
79
103
  x_train_true: Float2D
80
104
  x_train_true_dot: Float2D
81
- input_features: list[str]
105
+ noise_abs: float
82
106
  integrator: Optional[Any] = None # diffrax.Solution
83
107
 
84
108
 
@@ -138,7 +162,7 @@ class NestedDict(defaultdict):
138
162
 
139
163
  @dataclass
140
164
  class DynamicsTrialData:
141
- trajectories: list[ProbData]
165
+ trajectories: list[SimProbData]
142
166
  true_equations: list[dict[sp.Expr, float]]
143
167
  sindy_equations: list[dict[sp.Expr, float]]
144
168
  model: _BaseSINDy
sindy_exp/_utils.py CHANGED
@@ -149,7 +149,7 @@ def opt_lookup(kind):
149
149
  def coeff_metrics(
150
150
  coeff_est_dicts: list[dict[sp.Expr, float]],
151
151
  coeff_true_dicts: list[dict[sp.Expr, float]],
152
- ) -> dict[str, float | np.floating]:
152
+ ) -> dict[str, float]:
153
153
  """Compute coefficient metrics from aligned coefficient dictionaries.
154
154
 
155
155
  Both arguments are expected to be lists of coefficient dictionaries sharing
@@ -182,14 +182,18 @@ def coeff_metrics(
182
182
  coefficients[row_ind, col_ind] = est_row[feat]
183
183
 
184
184
  metrics: dict[str, float | np.floating] = {}
185
- metrics["coeff_precision"] = sklearn.metrics.precision_score(
186
- coeff_true.flatten() != 0, coefficients.flatten() != 0
185
+ metrics["coeff_precision"] = float(
186
+ sklearn.metrics.precision_score(
187
+ coeff_true.flatten() != 0, coefficients.flatten() != 0
188
+ )
187
189
  )
188
- metrics["coeff_recall"] = sklearn.metrics.recall_score(
189
- coeff_true.flatten() != 0, coefficients.flatten() != 0
190
+ metrics["coeff_recall"] = float(
191
+ sklearn.metrics.recall_score(
192
+ coeff_true.flatten() != 0, coefficients.flatten() != 0
193
+ )
190
194
  )
191
- metrics["coeff_f1"] = sklearn.metrics.f1_score(
192
- coeff_true.flatten() != 0, coefficients.flatten() != 0
195
+ metrics["coeff_f1"] = float(
196
+ sklearn.metrics.f1_score(coeff_true.flatten() != 0, coefficients.flatten() != 0)
193
197
  )
194
198
  metrics["coeff_mse"] = sklearn.metrics.mean_squared_error(
195
199
  coeff_true.flatten(), coefficients.flatten()
@@ -198,7 +202,7 @@ def coeff_metrics(
198
202
  coeff_true.flatten(), coefficients.flatten()
199
203
  )
200
204
  metrics["main"] = metrics["coeff_f1"]
201
- return metrics
205
+ return {k: float(v) for k, v in metrics.items()}
202
206
 
203
207
 
204
208
  def pred_metrics(
@@ -279,53 +283,8 @@ def unionize_coeff_dicts(
279
283
  return true_aligned, est_aligned
280
284
 
281
285
 
282
- def make_model(
283
- input_features: list[str],
284
- dt: float,
285
- diff_params: dict | ps.BaseDifferentiation,
286
- feat_params: dict | ps.feature_library.base.BaseFeatureLibrary,
287
- opt_params: dict | ps.BaseOptimizer,
288
- ) -> ps.SINDy:
289
- """Build a model with object parameters dictionaries
290
-
291
- e.g. {"kind": "finitedifference"} instead of FiniteDifference()
292
- """
293
-
294
- def finalize_param(lookup_func, pdict, lookup_key):
295
- try:
296
- cls_name = pdict.pop(lookup_key)
297
- except AttributeError:
298
- cls_name = pdict.vals.pop(lookup_key)
299
- pdict = pdict.vals
300
-
301
- param_cls = lookup_func(cls_name)
302
- param_final = param_cls(**pdict)
303
- pdict[lookup_key] = cls_name
304
- return param_final
305
-
306
- if isinstance(diff_params, ps.BaseDifferentiation):
307
- diff = diff_params
308
- else:
309
- diff = finalize_param(diff_lookup, diff_params, "diffcls")
310
- if isinstance(feat_params, ps.feature_library.base.BaseFeatureLibrary):
311
- features = feat_params
312
- else:
313
- features = finalize_param(feature_lookup, feat_params, "featcls")
314
- if isinstance(opt_params, ps.BaseOptimizer):
315
- opt = opt_params
316
- else:
317
- opt = finalize_param(opt_lookup, opt_params, "optcls")
318
- return ps.SINDy(
319
- differentiation_method=diff,
320
- optimizer=opt,
321
- t_default=dt, # type: ignore
322
- feature_library=features,
323
- feature_names=input_features,
324
- )
325
-
326
-
327
286
  def _simulate_test_data(
328
- model: _BaseSINDy, dt: float, x_test: Float2D
287
+ model: _BaseSINDy, t_test: Float1D, x_test: Float2D
329
288
  ) -> SINDyTrialUpdate:
330
289
  """Add simulation data to grid_data
331
290
 
@@ -333,7 +292,6 @@ def _simulate_test_data(
333
292
  Returns:
334
293
  Complete GridPointData
335
294
  """
336
- t_test = cast(Float1D, np.arange(0, len(x_test) * dt, step=dt))
337
295
  t_sim = t_test
338
296
  try:
339
297
 
sindy_exp/py.typed ADDED
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sindy-exp
3
- Version: 0.2.1
3
+ Version: 0.3.0
4
4
  Summary: A basic library for constructing dynamics experiments
5
5
  Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
6
6
  License: MIT License
@@ -34,7 +34,7 @@ Classifier: Intended Audience :: Science/Research
34
34
  Classifier: License :: OSI Approved :: MIT License
35
35
  Classifier: Natural Language :: English
36
36
  Classifier: Operating System :: POSIX :: Linux
37
- Requires-Python: >=3.10
37
+ Requires-Python: >=3.12
38
38
  Description-Content-Type: text/markdown
39
39
  License-File: LICENSE
40
40
  Requires-Dist: matplotlib
@@ -43,6 +43,7 @@ Requires-Dist: seaborn
43
43
  Requires-Dist: scipy
44
44
  Requires-Dist: sympy
45
45
  Requires-Dist: dysts
46
+ Requires-Dist: scikit-learn
46
47
  Provides-Extra: jax
47
48
  Requires-Dist: jax[cuda12]; extra == "jax"
48
49
  Requires-Dist: diffrax; extra == "jax"
@@ -64,14 +65,20 @@ Requires-Dist: tomli; extra == "dev"
64
65
  Requires-Dist: pysindy>=2.1.0; extra == "dev"
65
66
  Dynamic: license-file
66
67
 
67
- # Dynamics Experiments
68
+ # Overview
68
69
 
69
- A library for constructing dynamics experiments.
70
- This includes data generation and plotting/evaluation.
70
+ A library for constructing dynamics experiments from the dynamics models in the `dysts` package.
71
+ This includes data generation and model evaluation.
72
+ The first contribution is static typing for trajectory data (`ProbData`), which, I believe, carries the information needed to evaluate a wide variety of dynamics/time-series learning methods.
73
+ The second contribution is the collection of utility functions for designing dynamics learning experiments.
74
+ The third contribution is the collection of such experiments for evaluating dynamics/time-series learning models that meet the `BaseSINDy` API.
75
+
76
+ It aims to (a) be amenable to both `numpy` and `jax` arrays and (b) be usable by any dynamics/time-series learning model that meets the `BaseSINDy` or scikit-time API.
77
+ Internally, this package is (or will be) used to benchmark pysindy runtime and memory usage and to choose default hyperparameters.
71
78
 
72
79
  ## Getting started
73
80
 
74
- It's not yet on PyPI, so install it with `pip install sindy_exp @ git+https://github.com/Jacob-Stevens-Haas/sindy-experiments`
81
+ Install with `pip install sindy-exp` or `pip install sindy-exp[jax]`.
75
82
 
76
83
  Generate data
77
84
 
@@ -86,26 +93,26 @@ Evaluate your SINDy-like model with:
86
93
  A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
87
94
  of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
88
95
 
89
- ## ODE representation
96
+ ## ODE & Data Model
97
+
98
+ Measured or generated data has the dataclass type `ProbData` or `SimProbData`, respectively,
99
+ to indicate whether it includes ground truth information and a noise level.
100
+ If the data is generated in jax, it will have an integrator that can later be used to evaluate the true data on collocation points.
90
101
 
91
102
  We deal primarily with autonomous ODE systems of the form:
92
103
 
93
104
  dx/dt = sum_i f_i(x)
94
105
 
95
- Thus, we represent ODE systems as a list of right-hand side expressions.
106
+ We represent ODE systems as a list of right-hand side expressions.
96
107
  Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
108
+ Thus, the rhs of an ODE is of type: `list[dict[sympy.Expr, float]]`
97
109
 
98
110
  ## Other useful imports, compatibility, and extensions
99
111
 
100
- This is built to be compatible with dynamics learning models that follow the
101
- pysindy _BaseSINDy interface.
102
- The experiments are also built to be compatible with the `mitosis` tool,
103
- an experiment runner.
104
- To integrate your own experiments or data generation in a way that is compatible,
105
- see the `ProbData` and `DynamicsTrialData` classes.
106
- For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`,
107
- `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
108
- For metrics, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
112
+ * The experiments are built to be compatible with the `mitosis` tool, an experiment runner. However, mitosis is not a dependency, so other experiment runners can be used instead.
113
+ * To integrate your own experiments or data generation in a way that is compatible, see the `ProbData`, `SimProbData`, `DynamicsTrialData`, and `FullDynamicsTrialData` classes.
114
+ * For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`, `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
115
+ * For evaluation of models, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
109
116
 
110
117
  ![3d plot](images/composite.png)
111
118
  ![1d plot](images/1d.png)
@@ -0,0 +1,15 @@
1
+ sindy_exp/__init__.py,sha256=DJuJM0QVe7zmpOHY--7U8m4MCU8XkL_WKGD1H4skSM0,773
2
+ sindy_exp/_data.py,sha256=GwGAnLvEgxRjrwGLTe8RC6GAIOocJtDBey1Dhj8Gyk4,9744
3
+ sindy_exp/_diffrax_solver.py,sha256=HvhmUT9u9A06c6XRSgewdbKl0JEVB312UMmI_Lb4DVo,2634
4
+ sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
5
+ sindy_exp/_odes.py,sha256=E7BdvFoPI7kxNh0Q9WMpE7MsKGbke5zDjX5Ml31amkQ,5927
6
+ sindy_exp/_plotting.py,sha256=mfsDz_693XQS6tMZ8iQluRJqZMcmxyoZeIMP4DfJX2s,17688
7
+ sindy_exp/_typing.py,sha256=hxIxBx52oKMjd1LA8PItwrXlP-PtNGLoZm6oIMZi18I,4536
8
+ sindy_exp/_utils.py,sha256=RRFtk1iFWeo0Rpw1vYhoNSoABLJcRmGo1sIgTZ9lnWM,11299
9
+ sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
10
+ sindy_exp/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ sindy_exp-0.3.0.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
12
+ sindy_exp-0.3.0.dist-info/METADATA,sha256=qTE2umRXW5K9mCtTIt2u-vvtht92rxII90K94ZYaPOE,5669
13
+ sindy_exp-0.3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
14
+ sindy_exp-0.3.0.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
15
+ sindy_exp-0.3.0.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- sindy_exp/__init__.py,sha256=F1bJz9Gzk2nPB6DGjEa0qZcaQPTb-Yhh0ZnO9TcQci0,689
2
- sindy_exp/_data.py,sha256=X1WdFZ53FeqG1fwtnxQnCHUQqKtGSAvv3URWdcSavE8,6511
3
- sindy_exp/_diffrax_solver.py,sha256=YmKF-U1fkxW4z4AbXW_qm1sX6eEy1BGzKbqL63dxo2M,2705
4
- sindy_exp/_dysts_to_sympy.py,sha256=d_rvnfayOmFcGn4bZRJCfNGFO6yS1mw2QmBbOdWZwxg,15654
5
- sindy_exp/_odes.py,sha256=RN97nh1Y_6DQCLPfRTyCCJq3exxBq9pfVHVPTxr0na4,8830
6
- sindy_exp/_plotting.py,sha256=dpcqAXKzb0mSVl0p2WyMadVoGhQi43oL6ZqsbuheEuk,17470
7
- sindy_exp/_typing.py,sha256=hIOj270YkoITi_qJ8xTO1EaBeoVzcDidIGMXuC0QHzM,4046
8
- sindy_exp/_utils.py,sha256=RnMoNV_N7ubWVMbgdiJaRGo437DvWUcOIn4fP2rhwJI,12717
9
- sindy_exp/addl_attractors.json,sha256=KXoHWekFoa4KctjLCqcj_BpLBhXV0zlYrpgxV-uObwE,2928
10
- sindy_exp-0.2.1.dist-info/licenses/LICENSE,sha256=ubi77tIG3RVrqo0Z8cK91D4KZePQs-W1J-vJ-LkVOmE,1075
11
- sindy_exp-0.2.1.dist-info/METADATA,sha256=deZyEMmZibk24ZMU1Fc7s_qD-ibRwRSjfoNqW1enqpk,4514
12
- sindy_exp-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
13
- sindy_exp-0.2.1.dist-info/top_level.txt,sha256=0-tKKdmxHG3IRccz463rOb6xTsVJD-v9c8zSDpTRr5E,10
14
- sindy_exp-0.2.1.dist-info/RECORD,,