sindy-exp 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/PKG-INFO +1 -1
  2. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/pyproject.toml +3 -0
  3. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/__init__.py +2 -1
  4. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_data.py +119 -15
  5. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_diffrax_solver.py +0 -17
  6. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_odes.py +28 -112
  7. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_typing.py +9 -0
  8. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_utils.py +10 -51
  9. sindy_exp-0.2.2/src/sindy_exp/py.typed +0 -0
  10. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp.egg-info/PKG-INFO +1 -1
  11. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp.egg-info/SOURCES.txt +1 -0
  12. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/tests/test_all.py +1 -1
  13. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/.github/workflows/main.yaml +0 -0
  14. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/.github/workflows/release.yml +0 -0
  15. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/.gitignore +0 -0
  16. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/.pre-commit-config.yaml +0 -0
  17. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/CITATION.cff +0 -0
  18. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/LICENSE +0 -0
  19. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/README.md +0 -0
  20. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/setup.cfg +0 -0
  21. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_dysts_to_sympy.py +0 -0
  22. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/_plotting.py +0 -0
  23. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp/addl_attractors.json +0 -0
  24. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp.egg-info/dependency_links.txt +0 -0
  25. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp.egg-info/requires.txt +0 -0
  26. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/src/sindy_exp.egg-info/top_level.txt +0 -0
  27. {sindy_exp-0.2.1 → sindy_exp-0.2.2}/tests/test_inspect_to_sympy.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sindy-exp
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: A basic library for constructing dynamics experiments
5
5
  Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
6
6
  License: MIT License
@@ -95,6 +95,9 @@ markers = ["slow"]
95
95
  files = [
96
96
  "src/sindy_exp/__init__.py",
97
97
  "src/sindy_exp/_utils.py",
98
+ "src/sindy_exp/_diffrax_solver.py",
99
+ "src/sindy_exp/_odes.py",
100
+ "src/sindy_exp/_typing.py",
98
101
  "tests/test_all.py",
99
102
  ]
100
103
 
@@ -7,7 +7,7 @@ from ._plotting import (
7
7
  plot_test_trajectory,
8
8
  plot_training_data,
9
9
  )
10
- from ._typing import DynamicsTrialData, ProbData
10
+ from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData
11
11
  from ._utils import coeff_metrics, integration_metrics, pred_metrics
12
12
 
13
13
  __all__ = [
@@ -15,6 +15,7 @@ __all__ = [
15
15
  "fit_eval",
16
16
  "ProbData",
17
17
  "DynamicsTrialData",
18
+ "FullDynamicsTrialData",
18
19
  "coeff_metrics",
19
20
  "pred_metrics",
20
21
  "integration_metrics",
@@ -1,15 +1,17 @@
1
+ from importlib import resources
1
2
  from logging import getLogger
2
- from typing import Any, Callable, Optional, cast
3
+ from typing import Callable, Optional, cast
3
4
 
4
5
  import dysts.flows
5
6
  import dysts.systems
6
7
  import numpy as np
7
8
  import scipy
9
+ import sympy as sp
10
+ from dysts.base import DynSys
8
11
 
9
12
  from ._dysts_to_sympy import dynsys_to_sympy
10
- from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
11
13
  from ._plotting import plot_training_data
12
- from ._typing import Float1D, ProbData
14
+ from ._typing import ExperimentResult, Float1D, ProbData
13
15
  from ._utils import _sympy_expr_to_feat_coeff
14
16
 
15
17
  try:
@@ -21,21 +23,12 @@ except ImportError:
21
23
 
22
24
  INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
23
25
  MOD_LOG = getLogger(__name__)
26
+ LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
24
27
 
25
28
  ODE_CLASSES = {
26
29
  klass.lower(): getattr(dysts.flows, klass)
27
30
  for klass in dysts.systems.get_attractor_list()
28
31
  }
29
- ODE_CLASSES.update(
30
- {
31
- "lotkavolterra": LotkaVolterra,
32
- "sho": SHO,
33
- "cubicho": CubicHO,
34
- "hopf": Hopf,
35
- "vanderpol": VanDerPol,
36
- "kinematics": Kinematics,
37
- }
38
- )
39
32
 
40
33
 
41
34
  def gen_data(
@@ -49,7 +42,7 @@ def gen_data(
49
42
  t_end: float = 10,
50
43
  display: bool = False,
51
44
  array_namespace: str = "numpy",
52
- ) -> dict[str, Any]:
45
+ ) -> ExperimentResult[tuple[list[ProbData], list[dict[sp.Expr, float]]]]:
53
46
  """Generate random training and test data
54
47
 
55
48
  An Experiment step according to the mitosis experiment runner.
@@ -146,7 +139,7 @@ def gen_data(
146
139
  figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
147
140
  figs[0].suptitle("Sample Trajectory")
148
141
  return {
149
- "data": {"trajectories": prob_data_list, "coeff_true": coeff_true},
142
+ "data": (prob_data_list, coeff_true),
150
143
  "main": f"{n_trajectories} trajectories of {rhsfunc}",
151
144
  "metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
152
145
  }
@@ -200,3 +193,114 @@ def _max_amplitude(signal: np.ndarray, axis: int) -> float:
200
193
 
201
194
  def _signal_avg_power(signal: np.ndarray) -> float:
202
195
  return np.square(signal).mean()
196
+
197
+
198
+ def _register_dyst(klass: type[DynSys]) -> type[DynSys]:
199
+ """Register a custom dysts DynSys class for use in sindy_exp data generation."""
200
+ ODE_CLASSES[klass.__name__.lower()] = klass
201
+ return klass
202
+
203
+
204
+ @_register_dyst
205
+ class LotkaVolterra(DynSys):
206
+ """Lotka-Volterra (predator-prey) dynamical system."""
207
+
208
+ nonnegative = True
209
+
210
+ def __init__(self):
211
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
212
+
213
+ @staticmethod
214
+ def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
215
+ """LV dynamics
216
+
217
+ Args:
218
+ x: prey population
219
+ y: predator population
220
+ t: time (ignored, since autonomous)
221
+ alpha: prey growth rate
222
+ beta: predation rate
223
+ delta: predator reproduction rate
224
+ gamma: predator death rate
225
+ """
226
+ dxdt = alpha * x - beta * x * y
227
+ dydt = delta * x * y - gamma * y
228
+ return np.array([dxdt, dydt])
229
+
230
+
231
+ @_register_dyst
232
+ class Hopf(DynSys):
233
+ """Hopf normal form dynamical system."""
234
+
235
+ def __init__(self):
236
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
237
+
238
+ @staticmethod
239
+ def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
240
+ dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
241
+ dydt = omega * x + mu * y - A * (x**2 * y + y**3)
242
+ return np.array([dxdt, dydt])
243
+
244
+
245
+ @_register_dyst
246
+ class SHO(DynSys):
247
+ """Linear damped simple harmonic oscillator"""
248
+
249
+ def __init__(self):
250
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
251
+
252
+ @staticmethod
253
+ def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
254
+ dxdt = a * x + b * y
255
+ dydt = c * x + d * y
256
+ return np.array([dxdt, dydt])
257
+
258
+
259
+ @_register_dyst
260
+ class CubicHO(DynSys):
261
+ """Cubic damped harmonic oscillator."""
262
+
263
+ def __init__(self):
264
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
265
+
266
+ @staticmethod
267
+ def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
268
+ dxdt = a * x**3 + b * y**3
269
+ dydt = c * x**3 + d * y**3
270
+ return np.array([dxdt, dydt])
271
+
272
+
273
+ @_register_dyst
274
+ class VanDerPol(DynSys):
275
+ """Van der Pol oscillator.
276
+
277
+ dx/dt = y
278
+ dy/dt = mu * (1 - x^2) * y - x
279
+ """
280
+
281
+ def __init__(self):
282
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
283
+
284
+ @staticmethod
285
+ def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
286
+ dxdt = x_dot
287
+ dx2dt2 = mu * (1 - x**2) * x_dot - x
288
+ return np.array([dxdt, dx2dt2])
289
+
290
+
291
+ @_register_dyst
292
+ class Kinematics(DynSys):
293
+ """One-dimensional kinematics with constant acceleration.
294
+
295
+ dx/dt = v
296
+ dv/dt = a
297
+ """
298
+
299
+ def __init__(self):
300
+ super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
301
+
302
+ @staticmethod
303
+ def _rhs(x, v, t: float, a) -> np.ndarray:
304
+ dxdt = v
305
+ dvdt = a
306
+ return np.array([dxdt, dvdt])
@@ -85,20 +85,3 @@ def _gen_data_jax(
85
85
 
86
86
  def _signal_avg_power(signal: jax.Array) -> jax.Array:
87
87
  return jnp.square(signal).mean()
88
-
89
-
90
- ## % # noqa:E266
91
- if __name__ == "__main__":
92
- # Debug example
93
- from sindy_exp._data import gen_data
94
-
95
- data_dict = gen_data(
96
- "valliselnino",
97
- seed=50,
98
- n_trajectories=1,
99
- ic_stdev=3,
100
- noise_rel=0.1,
101
- display=True,
102
- array_namespace="jax",
103
- )
104
- print(data_dict["input_features"])
@@ -1,12 +1,9 @@
1
- from importlib import resources
2
1
  from logging import getLogger
3
- from typing import Callable, TypeVar
2
+ from typing import Any, Callable, Literal, TypeVar, cast, overload
4
3
 
5
4
  import matplotlib.pyplot as plt
6
5
  import numpy as np
7
- import pysindy as ps
8
6
  import sympy as sp
9
- from dysts.base import DynSys
10
7
 
11
8
  from ._plotting import (
12
9
  compare_coefficient_plots_from_dicts,
@@ -15,6 +12,7 @@ from ._plotting import (
15
12
  )
16
13
  from ._typing import (
17
14
  DynamicsTrialData,
15
+ ExperimentResult,
18
16
  FullDynamicsTrialData,
19
17
  ProbData,
20
18
  SINDyTrialUpdate,
@@ -41,7 +39,6 @@ metric_ordering = {
41
39
  T = TypeVar("T", bound=int)
42
40
  DType = TypeVar("DType", bound=np.dtype)
43
41
  MOD_LOG = getLogger(__name__)
44
- LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
45
42
 
46
43
 
47
44
  def _add_forcing(
@@ -68,112 +65,30 @@ def _add_forcing(
68
65
  return sum_of_terms
69
66
 
70
67
 
71
- class LotkaVolterra(DynSys):
72
- """Lotka-Volterra (predator-prey) dynamical system."""
73
-
74
- nonnegative = True
75
-
76
- def __init__(self):
77
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
78
-
79
- @staticmethod
80
- def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
81
- """LV dynamics
82
-
83
- Args:
84
- x: prey population
85
- y: predator population
86
- t: time (ignored, since autonomous)
87
- alpha: prey growth rate
88
- beta: predation rate
89
- delta: predator reproduction rate
90
- gamma: predator death rate
91
- """
92
- dxdt = alpha * x - beta * x * y
93
- dydt = delta * x * y - gamma * y
94
- return np.array([dxdt, dydt])
95
-
96
-
97
- class Hopf(DynSys):
98
- """Hopf normal form dynamical system."""
99
-
100
- def __init__(self):
101
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
102
-
103
- @staticmethod
104
- def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
105
- dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
106
- dydt = omega * x + mu * y - A * (x**2 * y + y**3)
107
- return np.array([dxdt, dydt])
108
-
109
-
110
- class SHO(DynSys):
111
- """Linear damped simple harmonic oscillator"""
112
-
113
- def __init__(self):
114
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
115
-
116
- @staticmethod
117
- def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
118
- dxdt = a * x + b * y
119
- dydt = c * x + d * y
120
- return np.array([dxdt, dydt])
121
-
122
-
123
- class CubicHO(DynSys):
124
- """Cubic damped harmonic oscillator."""
125
-
126
- def __init__(self):
127
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
128
-
129
- @staticmethod
130
- def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
131
- dxdt = a * x**3 + b * y**3
132
- dydt = c * x**3 + d * y**3
133
- return np.array([dxdt, dydt])
134
-
135
-
136
- class VanDerPol(DynSys):
137
- """Van der Pol oscillator.
138
-
139
- dx/dt = y
140
- dy/dt = mu * (1 - x^2) * y - x
141
- """
142
-
143
- def __init__(self):
144
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
145
-
146
- @staticmethod
147
- def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
148
- dxdt = x_dot
149
- dx2dt2 = mu * (1 - x**2) * x_dot - x
150
- return np.array([dxdt, dx2dt2])
151
-
152
-
153
- class Kinematics(DynSys):
154
- """One-dimensional kinematics with constant acceleration.
155
-
156
- dx/dt = v
157
- dv/dt = a
158
- """
68
+ @overload
69
+ def fit_eval(
70
+ data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
71
+ model: _BaseSINDy,
72
+ simulations: Literal[False],
73
+ display: bool,
74
+ ) -> ExperimentResult[DynamicsTrialData]: ...
159
75
 
160
- def __init__(self):
161
- super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
162
76
 
163
- @staticmethod
164
- def _rhs(x, v, t: float, a) -> np.ndarray:
165
- dxdt = v
166
- dvdt = a
167
- return np.array([dxdt, dvdt])
77
+ @overload
78
+ def fit_eval(
79
+ data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
80
+ model: _BaseSINDy,
81
+ simulations: Literal[True],
82
+ display: bool,
83
+ ) -> ExperimentResult[FullDynamicsTrialData]: ...
168
84
 
169
85
 
170
86
  def fit_eval(
171
87
  data: tuple[list[ProbData], list[dict[sp.Expr, float]]],
172
- model: _BaseSINDy,
88
+ model: Any,
173
89
  simulations: bool = True,
174
90
  display: bool = True,
175
- return_all: bool = True,
176
- ) -> dict | tuple[dict, DynamicsTrialData | FullDynamicsTrialData]:
91
+ ) -> ExperimentResult:
177
92
  """Fit and evaluate a SINDy model on a set of trajectories.
178
93
 
179
94
  Args:
@@ -184,11 +99,8 @@ def fit_eval(
184
99
  model: A SINDy-like model implementing the _BaseSINDy protocol.
185
100
  simulations: Whether to run forward simulations for evaluation.
186
101
  display: Whether to generate plots as part of evaluation.
187
- return_all: If True, return a dictionary containing metrics and the
188
- assembled DynamicsTrialData; otherwise return only the metrics
189
- dictionary.
190
102
  """
191
-
103
+ model = cast(_BaseSINDy, model)
192
104
  trajectories, true_equations = data
193
105
  input_features = trajectories[0].input_features
194
106
 
@@ -198,12 +110,16 @@ def fit_eval(
198
110
 
199
111
  MOD_LOG.info(f"Fitting a model: {model}")
200
112
  coeff_true_dicts, coeff_est_dicts = unionize_coeff_dicts(model, true_equations)
201
- if isinstance(model.feature_library, ps.WeakPDELibrary):
113
+
114
+ # Special workaround for pysindy's legacy WeakPDELibrary
115
+ if hasattr(model.feature_library, "K"):
202
116
  # WeakPDE library fails to simulate, so insert nonweak library
203
117
  # to Pipeline and SINDy model.
204
118
  inner_lib = model.feature_library.function_library
205
119
  model.feature_library = inner_lib # type: ignore # TODO: Fix in pysindy
206
- if isinstance(model, ps.SINDy) and hasattr(
120
+
121
+ # Special workaround for pysindy's bad (soon to be legacy) differentiation API
122
+ if hasattr(model, "differentiation_method") and hasattr(
207
123
  model.differentiation_method, "smoothed_x_"
208
124
  ):
209
125
  smooth_x = []
@@ -212,6 +128,7 @@ def fit_eval(
212
128
  smooth_x.append(model.differentiation_method.smoothed_x_)
213
129
  else: # using WeakPDELibrary
214
130
  smooth_x = x_train
131
+
215
132
  trial_data = DynamicsTrialData(
216
133
  trajectories=trajectories,
217
134
  true_equations=coeff_true_dicts,
@@ -264,9 +181,8 @@ def fit_eval(
264
181
  figs=(fig_composite, fig_by_coord_1d),
265
182
  coord_names=input_features,
266
183
  )
267
- if return_all:
268
- return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
269
- return metrics
184
+
185
+ return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
270
186
 
271
187
 
272
188
  def plot_ode_panel(trial_data: DynamicsTrialData):
@@ -7,6 +7,7 @@ from typing import (
7
7
  NamedTuple,
8
8
  Optional,
9
9
  Protocol,
10
+ TypedDict,
10
11
  TypeVar,
11
12
  overload,
12
13
  )
@@ -27,6 +28,14 @@ FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]
27
28
  TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
28
29
 
29
30
 
31
+ class ExperimentResult[T](TypedDict):
32
+ """Results from a SINDy ODE experiment."""
33
+
34
+ metrics: float
35
+ data: T
36
+ main: float
37
+
38
+
30
39
  class _BaseSINDy(Protocol):
31
40
  optimizer: Any
32
41
  feature_library: Any
@@ -182,14 +182,18 @@ def coeff_metrics(
182
182
  coefficients[row_ind, col_ind] = est_row[feat]
183
183
 
184
184
  metrics: dict[str, float | np.floating] = {}
185
- metrics["coeff_precision"] = sklearn.metrics.precision_score(
186
- coeff_true.flatten() != 0, coefficients.flatten() != 0
185
+ metrics["coeff_precision"] = float(
186
+ sklearn.metrics.precision_score(
187
+ coeff_true.flatten() != 0, coefficients.flatten() != 0
188
+ )
187
189
  )
188
- metrics["coeff_recall"] = sklearn.metrics.recall_score(
189
- coeff_true.flatten() != 0, coefficients.flatten() != 0
190
+ metrics["coeff_recall"] = float(
191
+ sklearn.metrics.recall_score(
192
+ coeff_true.flatten() != 0, coefficients.flatten() != 0
193
+ )
190
194
  )
191
- metrics["coeff_f1"] = sklearn.metrics.f1_score(
192
- coeff_true.flatten() != 0, coefficients.flatten() != 0
195
+ metrics["coeff_f1"] = float(
196
+ sklearn.metrics.f1_score(coeff_true.flatten() != 0, coefficients.flatten() != 0)
193
197
  )
194
198
  metrics["coeff_mse"] = sklearn.metrics.mean_squared_error(
195
199
  coeff_true.flatten(), coefficients.flatten()
@@ -279,51 +283,6 @@ def unionize_coeff_dicts(
279
283
  return true_aligned, est_aligned
280
284
 
281
285
 
282
- def make_model(
283
- input_features: list[str],
284
- dt: float,
285
- diff_params: dict | ps.BaseDifferentiation,
286
- feat_params: dict | ps.feature_library.base.BaseFeatureLibrary,
287
- opt_params: dict | ps.BaseOptimizer,
288
- ) -> ps.SINDy:
289
- """Build a model with object parameters dictionaries
290
-
291
- e.g. {"kind": "finitedifference"} instead of FiniteDifference()
292
- """
293
-
294
- def finalize_param(lookup_func, pdict, lookup_key):
295
- try:
296
- cls_name = pdict.pop(lookup_key)
297
- except AttributeError:
298
- cls_name = pdict.vals.pop(lookup_key)
299
- pdict = pdict.vals
300
-
301
- param_cls = lookup_func(cls_name)
302
- param_final = param_cls(**pdict)
303
- pdict[lookup_key] = cls_name
304
- return param_final
305
-
306
- if isinstance(diff_params, ps.BaseDifferentiation):
307
- diff = diff_params
308
- else:
309
- diff = finalize_param(diff_lookup, diff_params, "diffcls")
310
- if isinstance(feat_params, ps.feature_library.base.BaseFeatureLibrary):
311
- features = feat_params
312
- else:
313
- features = finalize_param(feature_lookup, feat_params, "featcls")
314
- if isinstance(opt_params, ps.BaseOptimizer):
315
- opt = opt_params
316
- else:
317
- opt = finalize_param(opt_lookup, opt_params, "optcls")
318
- return ps.SINDy(
319
- differentiation_method=diff,
320
- optimizer=opt,
321
- t_default=dt, # type: ignore
322
- feature_library=features,
323
- feature_names=input_features,
324
- )
325
-
326
-
327
286
  def _simulate_test_data(
328
287
  model: _BaseSINDy, dt: float, x_test: Float2D
329
288
  ) -> SINDyTrialUpdate:
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sindy-exp
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: A basic library for constructing dynamics experiments
5
5
  Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
6
6
  License: MIT License
@@ -16,6 +16,7 @@ src/sindy_exp/_plotting.py
16
16
  src/sindy_exp/_typing.py
17
17
  src/sindy_exp/_utils.py
18
18
  src/sindy_exp/addl_attractors.json
19
+ src/sindy_exp/py.typed
19
20
  src/sindy_exp.egg-info/PKG-INFO
20
21
  src/sindy_exp.egg-info/SOURCES.txt
21
22
  src/sindy_exp.egg-info/dependency_links.txt
@@ -180,7 +180,7 @@ def test_gen_data(rhs_name, array_namespace, jax_cpu_only):
180
180
  result = gen_data(
181
181
  rhs_name, t_end=0.1, noise_abs=0.01, seed=42, array_namespace=array_namespace
182
182
  )["data"]
183
- trajectories = result["trajectories"]
183
+ trajectories = result[0]
184
184
  assert len(trajectories) == 1
185
185
  traj = trajectories[0]
186
186
  assert traj.x_train.shape == traj.x_train_true_dot.shape
File without changes
File without changes
File without changes
File without changes
File without changes