sindy-exp 0.2.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/.github/workflows/main.yaml +3 -3
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/.github/workflows/release.yml +1 -1
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/.gitignore +0 -1
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/PKG-INFO +24 -17
- sindy_exp-0.3.0/README.md +51 -0
- sindy_exp-0.3.0/images/1d.png +0 -0
- sindy_exp-0.3.0/images/coeff.png +0 -0
- sindy_exp-0.3.0/images/composite.png +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/pyproject.toml +12 -13
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/__init__.py +3 -1
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_data.py +135 -26
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_diffrax_solver.py +12 -21
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_odes.py +36 -142
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_plotting.py +11 -5
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_typing.py +31 -7
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_utils.py +13 -55
- sindy_exp-0.3.0/src/sindy_exp/py.typed +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp.egg-info/PKG-INFO +24 -17
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp.egg-info/SOURCES.txt +4 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp.egg-info/requires.txt +1 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/tests/test_all.py +1 -1
- sindy_exp-0.2.1/README.md +0 -45
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/.pre-commit-config.yaml +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/CITATION.cff +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/LICENSE +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/setup.cfg +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/_dysts_to_sympy.py +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp/addl_attractors.json +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp.egg-info/dependency_links.txt +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/src/sindy_exp.egg-info/top_level.txt +0 -0
- {sindy_exp-0.2.1 → sindy_exp-0.3.0}/tests/test_inspect_to_sympy.py +0 -0
|
@@ -17,7 +17,7 @@ jobs:
|
|
|
17
17
|
- name: "Set up Python"
|
|
18
18
|
uses: actions/setup-python@v3
|
|
19
19
|
with:
|
|
20
|
-
python-version: "3.
|
|
20
|
+
python-version: "3.12"
|
|
21
21
|
- name: run pre-commit
|
|
22
22
|
run: |
|
|
23
23
|
pip install pre-commit
|
|
@@ -30,7 +30,7 @@ jobs:
|
|
|
30
30
|
- name: "Set up Python"
|
|
31
31
|
uses: actions/setup-python@v3
|
|
32
32
|
with:
|
|
33
|
-
python-version: "3.
|
|
33
|
+
python-version: "3.12"
|
|
34
34
|
- name: install dependencies
|
|
35
35
|
run: |
|
|
36
36
|
pip install -e .[dev,jax]
|
|
@@ -45,7 +45,7 @@ jobs:
|
|
|
45
45
|
fail-fast: false
|
|
46
46
|
max-parallel: 4
|
|
47
47
|
matrix:
|
|
48
|
-
python-version: ["3.
|
|
48
|
+
python-version: ["3.12", "3.13"]
|
|
49
49
|
steps:
|
|
50
50
|
- uses: actions/checkout@v3
|
|
51
51
|
- name: Set up Python ${{ matrix.python-version }}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sindy-exp
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A basic library for constructing dynamics experiments
|
|
5
5
|
Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -34,7 +34,7 @@ Classifier: Intended Audience :: Science/Research
|
|
|
34
34
|
Classifier: License :: OSI Approved :: MIT License
|
|
35
35
|
Classifier: Natural Language :: English
|
|
36
36
|
Classifier: Operating System :: POSIX :: Linux
|
|
37
|
-
Requires-Python: >=3.
|
|
37
|
+
Requires-Python: >=3.12
|
|
38
38
|
Description-Content-Type: text/markdown
|
|
39
39
|
License-File: LICENSE
|
|
40
40
|
Requires-Dist: matplotlib
|
|
@@ -43,6 +43,7 @@ Requires-Dist: seaborn
|
|
|
43
43
|
Requires-Dist: scipy
|
|
44
44
|
Requires-Dist: sympy
|
|
45
45
|
Requires-Dist: dysts
|
|
46
|
+
Requires-Dist: scikit-learn
|
|
46
47
|
Provides-Extra: jax
|
|
47
48
|
Requires-Dist: jax[cuda12]; extra == "jax"
|
|
48
49
|
Requires-Dist: diffrax; extra == "jax"
|
|
@@ -64,14 +65,20 @@ Requires-Dist: tomli; extra == "dev"
|
|
|
64
65
|
Requires-Dist: pysindy>=2.1.0; extra == "dev"
|
|
65
66
|
Dynamic: license-file
|
|
66
67
|
|
|
67
|
-
#
|
|
68
|
+
# Overview
|
|
68
69
|
|
|
69
|
-
A library for constructing dynamics experiments.
|
|
70
|
-
This includes data generation and
|
|
70
|
+
A library for constructing dynamics experiments from the dynamics models in the `dysts` package.
|
|
71
|
+
This includes data generation and model evaluation.
|
|
72
|
+
The first contribution is the static typing of trajectory data (`ProbData`) that, I believe, provides the necessary information to be useful in evaluating a wide variety of dynamics/time-series learning methods.
|
|
73
|
+
The second contribution is the collection of utility functions for designing dynamics learning experiments.
|
|
74
|
+
The third contribution is the collection of such experiments for evaluating dynamics/time-series learning models that meet the `BaseSINDy` API.
|
|
75
|
+
|
|
76
|
+
It aims to (a) be amenable to both `numpy` and `jax` arrays, and (b) be usable by any dynamics/time-series learning model that meets the `BaseSINDy` or scikit-time API.
|
|
77
|
+
Internally, this package is used/will be used in benchmarking pysindy runtime/memory usage and choosing default hyperparameters.
|
|
71
78
|
|
|
72
79
|
## Getting started
|
|
73
80
|
|
|
74
|
-
|
|
81
|
+
Install with `pip install sindy-exp` or `pip install sindy-exp[jax]`.
|
|
75
82
|
|
|
76
83
|
Generate data
|
|
77
84
|
|
|
@@ -86,26 +93,26 @@ Evaluate your SINDy-like model with:
|
|
|
86
93
|
A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
|
|
87
94
|
of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
|
|
88
95
|
|
|
89
|
-
## ODE
|
|
96
|
+
## ODE & Data Model
|
|
97
|
+
|
|
98
|
+
Measured or generated data has the dataclass type `ProbData` or `SimProbData`, respectively,
|
|
99
|
+
to indicate whether it includes ground truth information and a noise level.
|
|
100
|
+
If the data is generated in jax, it will have an integrator that can later be used to evaluate the true data on collocation points.
|
|
90
101
|
|
|
91
102
|
We deal primarily with autonomous ODE systems of the form:
|
|
92
103
|
|
|
93
104
|
dx/dt = sum_i f_i(x)
|
|
94
105
|
|
|
95
|
-
|
|
106
|
+
We represent ODE systems as a list of right-hand side expressions.
|
|
96
107
|
Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
|
|
108
|
+
Thus, the rhs of an ODE is of type: `list[dict[sympy.Expr, float]]`
|
|
97
109
|
|
|
98
110
|
## Other useful imports, compatibility, and extensions
|
|
99
111
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
To integrate your own experiments or data generation in a way that is compatible,
|
|
105
|
-
see the `ProbData` and `DynamicsTrialData` classes.
|
|
106
|
-
For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`,
|
|
107
|
-
`plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
108
|
-
For metrics, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
112
|
+
* The experiments are built to be compatible with the `mitosis` tool, an experiment runner. Mitosis is not a dependency, however, to allow using other experiment runners.
|
|
113
|
+
* To integrate your own experiments or data generation in a way that is compatible, see the `ProbData`, `SimProbData`, `DynamicsTrialData`, and `FullDynamicsTrialData` classes.
|
|
114
|
+
* For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`, `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
115
|
+
* For evaluation of models, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
109
116
|
|
|
110
117
|

|
|
111
118
|

|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Overview
|
|
2
|
+
|
|
3
|
+
A library for constructing dynamics experiments from the dynamics models in the `dysts` package.
|
|
4
|
+
This includes data generation and model evaluation.
|
|
5
|
+
The first contribution is the static typing of trajectory data (`ProbData`) that, I believe, provides the necessary information to be useful in evaluating a wide variety of dynamics/time-series learning methods.
|
|
6
|
+
The second contribution is the collection of utility functions for designing dynamics learning experiments.
|
|
7
|
+
The third contribution is the collection of such experiments for evaluating dynamics/time-series learning models that meet the `BaseSINDy` API.
|
|
8
|
+
|
|
9
|
+
It aims to (a) be amenable to both `numpy` and `jax` arrays, and (b) be usable by any dynamics/time-series learning model that meets the `BaseSINDy` or scikit-time API.
|
|
10
|
+
Internally, this package is used/will be used in benchmarking pysindy runtime/memory usage and choosing default hyperparameters.
|
|
11
|
+
|
|
12
|
+
## Getting started
|
|
13
|
+
|
|
14
|
+
Install with `pip install sindy-exp` or `pip install sindy-exp[jax]`.
|
|
15
|
+
|
|
16
|
+
Generate data
|
|
17
|
+
|
|
18
|
+
data = sindy_exp.data.gen_data("lorenz", n_trajectories=5, t_end=10.0, dt=0.01)["data"]
|
|
19
|
+
|
|
20
|
+
Evaluate your SINDy-like model with:
|
|
21
|
+
|
|
22
|
+
sindy_exp.odes.fit_eval(data, model)
|
|
23
|
+
|
|
24
|
+

|
|
25
|
+
|
|
26
|
+
A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
|
|
27
|
+
of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
|
|
28
|
+
|
|
29
|
+
## ODE & Data Model
|
|
30
|
+
|
|
31
|
+
Measured or generated data has the dataclass type `ProbData` or `SimProbData`, respectively,
|
|
32
|
+
to indicate whether it includes ground truth information and a noise level.
|
|
33
|
+
If the data is generated in jax, it will have an integrator that can later be used to evaluate the true data on collocation points.
|
|
34
|
+
|
|
35
|
+
We deal primarily with autonomous ODE systems of the form:
|
|
36
|
+
|
|
37
|
+
dx/dt = sum_i f_i(x)
|
|
38
|
+
|
|
39
|
+
We represent ODE systems as a list of right-hand side expressions.
|
|
40
|
+
Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
|
|
41
|
+
Thus, the rhs of an ODE is of type: `list[dict[sympy.Expr, float]]`
|
|
42
|
+
|
|
43
|
+
## Other useful imports, compatibility, and extensions
|
|
44
|
+
|
|
45
|
+
* The experiments are built to be compatible with the `mitosis` tool, an experiment runner. Mitosis is not a dependency, however, to allow using other experiment runners.
|
|
46
|
+
* To integrate your own experiments or data generation in a way that is compatible, see the `ProbData`, `SimProbData`, `DynamicsTrialData`, and `FullDynamicsTrialData` classes.
|
|
47
|
+
* For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`, `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
48
|
+
* For evaluation of models, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
49
|
+
|
|
50
|
+

|
|
51
|
+

|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -7,7 +7,7 @@ name = "sindy-exp"
|
|
|
7
7
|
dynamic = ["version"]
|
|
8
8
|
description = "A basic library for constructing dynamics experiments"
|
|
9
9
|
readme = "README.md"
|
|
10
|
-
requires-python = ">=3.
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
11
|
license = {file = "LICENSE"}
|
|
12
12
|
keywords = ["Machine Learning", "Science", "Mathematics", "Experiments"]
|
|
13
13
|
authors = [
|
|
@@ -29,6 +29,7 @@ dependencies = [
|
|
|
29
29
|
"scipy",
|
|
30
30
|
"sympy",
|
|
31
31
|
"dysts",
|
|
32
|
+
"scikit-learn",
|
|
32
33
|
]
|
|
33
34
|
|
|
34
35
|
[project.optional-dependencies]
|
|
@@ -94,22 +95,20 @@ markers = ["slow"]
|
|
|
94
95
|
[tool.mypy]
|
|
95
96
|
files = [
|
|
96
97
|
"src/sindy_exp/__init__.py",
|
|
98
|
+
"src/sindy_exp/_data.py",
|
|
97
99
|
"src/sindy_exp/_utils.py",
|
|
100
|
+
"src/sindy_exp/_diffrax_solver.py",
|
|
101
|
+
"src/sindy_exp/_odes.py",
|
|
102
|
+
"src/sindy_exp/_plotting.py",
|
|
103
|
+
"src/sindy_exp/_typing.py",
|
|
98
104
|
"tests/test_all.py",
|
|
99
105
|
]
|
|
106
|
+
warn_unused_configs = true
|
|
100
107
|
|
|
101
108
|
[[tool.mypy.overrides]]
|
|
102
|
-
module="
|
|
103
|
-
|
|
109
|
+
module = ["sindy_exp.*"]
|
|
110
|
+
disable_error_code = ["import-untyped"]
|
|
104
111
|
|
|
105
112
|
[[tool.mypy.overrides]]
|
|
106
|
-
module="pysindy.*"
|
|
107
|
-
ignore_missing_imports=true
|
|
108
|
-
|
|
109
|
-
[[tool.mypy.overrides]]
|
|
110
|
-
module="sympy.*"
|
|
111
|
-
ignore_missing_imports=true
|
|
112
|
-
|
|
113
|
-
[[tool.mypy.overrides]]
|
|
114
|
-
module="scipy.*"
|
|
115
|
-
ignore_missing_imports=true
|
|
113
|
+
module = ["pysindy.*", "sympy.*"]
|
|
114
|
+
ignore_missing_imports = true
|
|
@@ -7,14 +7,16 @@ from ._plotting import (
|
|
|
7
7
|
plot_test_trajectory,
|
|
8
8
|
plot_training_data,
|
|
9
9
|
)
|
|
10
|
-
from ._typing import DynamicsTrialData, ProbData
|
|
10
|
+
from ._typing import DynamicsTrialData, FullDynamicsTrialData, ProbData, SimProbData
|
|
11
11
|
from ._utils import coeff_metrics, integration_metrics, pred_metrics
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
14
14
|
"gen_data",
|
|
15
15
|
"fit_eval",
|
|
16
16
|
"ProbData",
|
|
17
|
+
"SimProbData",
|
|
17
18
|
"DynamicsTrialData",
|
|
19
|
+
"FullDynamicsTrialData",
|
|
18
20
|
"coeff_metrics",
|
|
19
21
|
"pred_metrics",
|
|
20
22
|
"integration_metrics",
|
|
@@ -1,15 +1,17 @@
|
|
|
1
|
+
from importlib import resources
|
|
1
2
|
from logging import getLogger
|
|
2
|
-
from typing import
|
|
3
|
+
from typing import Callable, Optional, cast
|
|
3
4
|
|
|
4
5
|
import dysts.flows
|
|
5
6
|
import dysts.systems
|
|
6
7
|
import numpy as np
|
|
7
8
|
import scipy
|
|
9
|
+
import sympy as sp
|
|
10
|
+
from dysts.base import DynSys
|
|
8
11
|
|
|
9
12
|
from ._dysts_to_sympy import dynsys_to_sympy
|
|
10
|
-
from ._odes import SHO, CubicHO, Hopf, Kinematics, LotkaVolterra, VanDerPol
|
|
11
13
|
from ._plotting import plot_training_data
|
|
12
|
-
from ._typing import Float1D,
|
|
14
|
+
from ._typing import ExperimentResult, Float1D, SimProbData
|
|
13
15
|
from ._utils import _sympy_expr_to_feat_coeff
|
|
14
16
|
|
|
15
17
|
try:
|
|
@@ -21,21 +23,12 @@ except ImportError:
|
|
|
21
23
|
|
|
22
24
|
INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
|
|
23
25
|
MOD_LOG = getLogger(__name__)
|
|
26
|
+
LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
|
|
24
27
|
|
|
25
28
|
ODE_CLASSES = {
|
|
26
29
|
klass.lower(): getattr(dysts.flows, klass)
|
|
27
30
|
for klass in dysts.systems.get_attractor_list()
|
|
28
31
|
}
|
|
29
|
-
ODE_CLASSES.update(
|
|
30
|
-
{
|
|
31
|
-
"lotkavolterra": LotkaVolterra,
|
|
32
|
-
"sho": SHO,
|
|
33
|
-
"cubicho": CubicHO,
|
|
34
|
-
"hopf": Hopf,
|
|
35
|
-
"vanderpol": VanDerPol,
|
|
36
|
-
"kinematics": Kinematics,
|
|
37
|
-
}
|
|
38
|
-
)
|
|
39
32
|
|
|
40
33
|
|
|
41
34
|
def gen_data(
|
|
@@ -49,7 +42,7 @@ def gen_data(
|
|
|
49
42
|
t_end: float = 10,
|
|
50
43
|
display: bool = False,
|
|
51
44
|
array_namespace: str = "numpy",
|
|
52
|
-
) -> dict[
|
|
45
|
+
) -> ExperimentResult[tuple[list[SimProbData], list[dict[sp.Expr, float]]]]:
|
|
53
46
|
"""Generate random training and test data
|
|
54
47
|
|
|
55
48
|
An Experiment step according to the mitosis experiment runner.
|
|
@@ -82,7 +75,7 @@ def gen_data(
|
|
|
82
75
|
coeff_true = _sympy_expr_to_feat_coeff(sp_expr)
|
|
83
76
|
rhsfunc = lambda t, X: dyst_sys.rhs(X, t) # noqa: E731
|
|
84
77
|
try:
|
|
85
|
-
x0_center = dyst_sys.ic
|
|
78
|
+
x0_center = cast(Float1D, dyst_sys.ic)
|
|
86
79
|
except KeyError:
|
|
87
80
|
x0_center = np.zeros((len(input_features)), dtype=np.float64)
|
|
88
81
|
try:
|
|
@@ -95,7 +88,7 @@ def gen_data(
|
|
|
95
88
|
noise_abs = 0.1
|
|
96
89
|
|
|
97
90
|
MOD_LOG.info(f"Generating {n_trajectories} trajectories of f{system}")
|
|
98
|
-
prob_data_list: list[
|
|
91
|
+
prob_data_list: list[SimProbData] = []
|
|
99
92
|
if array_namespace == "numpy":
|
|
100
93
|
feature_names = [feat.name for feat in input_features]
|
|
101
94
|
for _ in range(n_trajectories):
|
|
@@ -115,20 +108,20 @@ def gen_data(
|
|
|
115
108
|
prob_data_list.append(prob)
|
|
116
109
|
elif array_namespace == "jax":
|
|
117
110
|
try:
|
|
118
|
-
|
|
119
|
-
except
|
|
111
|
+
jax # type: ignore
|
|
112
|
+
except NameError:
|
|
120
113
|
raise ImportError(
|
|
121
114
|
"jax data generation requested but diffrax or sympy2jax not"
|
|
122
115
|
" installed"
|
|
123
116
|
)
|
|
124
|
-
this_seed = jax.random.PRNGKey(seed)
|
|
117
|
+
this_seed = jax.random.PRNGKey(seed) # type: ignore
|
|
125
118
|
for _ in range(n_trajectories):
|
|
126
|
-
this_seed, _ = jax.random.split(this_seed)
|
|
127
|
-
prob = _gen_data_jax(
|
|
119
|
+
this_seed, _ = jax.random.split(this_seed) # type: ignore
|
|
120
|
+
prob = _gen_data_jax( # type: ignore
|
|
128
121
|
sp_expr,
|
|
129
122
|
input_features,
|
|
130
123
|
this_seed,
|
|
131
|
-
x0_center=x0_center,
|
|
124
|
+
x0_center=x0_center, # type: ignore # numpy->jax
|
|
132
125
|
nonnegative=nonnegative,
|
|
133
126
|
ic_stdev=ic_stdev,
|
|
134
127
|
noise_abs=noise_abs,
|
|
@@ -143,10 +136,12 @@ def gen_data(
|
|
|
143
136
|
)
|
|
144
137
|
if display and prob_data_list:
|
|
145
138
|
sample = prob_data_list[0]
|
|
139
|
+
assert sample.x_train_true is not None # typing
|
|
146
140
|
figs = plot_training_data(sample.t_train, sample.x_train, sample.x_train_true)
|
|
147
141
|
figs[0].suptitle("Sample Trajectory")
|
|
142
|
+
|
|
148
143
|
return {
|
|
149
|
-
"data":
|
|
144
|
+
"data": (prob_data_list, coeff_true),
|
|
150
145
|
"main": f"{n_trajectories} trajectories of {rhsfunc}",
|
|
151
146
|
"metrics": {"rel_noise": noise_rel, "abs_noise": noise_abs},
|
|
152
147
|
}
|
|
@@ -163,7 +158,7 @@ def _gen_data(
|
|
|
163
158
|
nonnegative: bool,
|
|
164
159
|
dt: float,
|
|
165
160
|
t_end: float,
|
|
166
|
-
) ->
|
|
161
|
+
) -> SimProbData:
|
|
167
162
|
rng = np.random.default_rng(seed)
|
|
168
163
|
t_train = np.arange(0, t_end, dt)
|
|
169
164
|
t_train_span = (t_train[0], t_train[-1])
|
|
@@ -187,8 +182,9 @@ def _gen_data(
|
|
|
187
182
|
noise_abs = np.sqrt(_signal_avg_power(x_train) * noise_rel)
|
|
188
183
|
x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape)
|
|
189
184
|
|
|
190
|
-
|
|
191
|
-
|
|
185
|
+
assert noise_abs is not None # typing
|
|
186
|
+
return SimProbData(
|
|
187
|
+
t_train, x_train, input_features, x_train_true, x_train_true_dot, noise_abs
|
|
192
188
|
)
|
|
193
189
|
|
|
194
190
|
|
|
@@ -200,3 +196,116 @@ def _max_amplitude(signal: np.ndarray, axis: int) -> float:
|
|
|
200
196
|
|
|
201
197
|
def _signal_avg_power(signal: np.ndarray) -> float:
|
|
202
198
|
return np.square(signal).mean()
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _register_dyst(klass: type[DynSys]) -> type[DynSys]:
|
|
202
|
+
"""Register a custom dysts DynSys class for use in sindy_exp data generation."""
|
|
203
|
+
ODE_CLASSES[klass.__name__.lower()] = klass
|
|
204
|
+
return klass
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@_register_dyst
|
|
208
|
+
class LotkaVolterra(DynSys):
|
|
209
|
+
"""Lotka-Volterra (predator-prey) dynamical system."""
|
|
210
|
+
|
|
211
|
+
nonnegative = True
|
|
212
|
+
|
|
213
|
+
def __init__(self):
|
|
214
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
215
|
+
|
|
216
|
+
@staticmethod
|
|
217
|
+
def _rhs( # type: ignore # dysts
|
|
218
|
+
x, y, t: float, alpha, beta, gamma, delta
|
|
219
|
+
) -> np.ndarray:
|
|
220
|
+
"""LV dynamics
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
x: prey population
|
|
224
|
+
y: predator population
|
|
225
|
+
t: time (ignored, since autonomous)
|
|
226
|
+
alpha: prey growth rate
|
|
227
|
+
beta: predation rate
|
|
228
|
+
delta: predator reproduction rate
|
|
229
|
+
gamma: predator death rate
|
|
230
|
+
"""
|
|
231
|
+
dxdt = alpha * x - beta * x * y
|
|
232
|
+
dydt = delta * x * y - gamma * y
|
|
233
|
+
return np.array([dxdt, dydt])
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@_register_dyst
|
|
237
|
+
class Hopf(DynSys):
|
|
238
|
+
"""Hopf normal form dynamical system."""
|
|
239
|
+
|
|
240
|
+
def __init__(self):
|
|
241
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
242
|
+
|
|
243
|
+
@staticmethod
|
|
244
|
+
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray: # type: ignore # dysts
|
|
245
|
+
dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
|
|
246
|
+
dydt = omega * x + mu * y - A * (x**2 * y + y**3)
|
|
247
|
+
return np.array([dxdt, dydt])
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@_register_dyst
|
|
251
|
+
class SHO(DynSys):
|
|
252
|
+
"""Linear damped simple harmonic oscillator"""
|
|
253
|
+
|
|
254
|
+
def __init__(self):
|
|
255
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
256
|
+
|
|
257
|
+
@staticmethod
|
|
258
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
|
|
259
|
+
dxdt = a * x + b * y
|
|
260
|
+
dydt = c * x + d * y
|
|
261
|
+
return np.array([dxdt, dydt])
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@_register_dyst
|
|
265
|
+
class CubicHO(DynSys):
|
|
266
|
+
"""Cubic damped harmonic oscillator."""
|
|
267
|
+
|
|
268
|
+
def __init__(self):
|
|
269
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
270
|
+
|
|
271
|
+
@staticmethod
|
|
272
|
+
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray: # type: ignore # dysts
|
|
273
|
+
dxdt = a * x**3 + b * y**3
|
|
274
|
+
dydt = c * x**3 + d * y**3
|
|
275
|
+
return np.array([dxdt, dydt])
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@_register_dyst
|
|
279
|
+
class VanDerPol(DynSys):
|
|
280
|
+
"""Van der Pol oscillator.
|
|
281
|
+
|
|
282
|
+
dx/dt = y
|
|
283
|
+
dy/dt = mu * (1 - x^2) * y - x
|
|
284
|
+
"""
|
|
285
|
+
|
|
286
|
+
def __init__(self):
|
|
287
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
288
|
+
|
|
289
|
+
@staticmethod
|
|
290
|
+
def _rhs(x, x_dot, t: float, mu) -> np.ndarray: # type: ignore # dysts
|
|
291
|
+
dxdt = x_dot
|
|
292
|
+
dx2dt2 = mu * (1 - x**2) * x_dot - x
|
|
293
|
+
return np.array([dxdt, dx2dt2])
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@_register_dyst
|
|
297
|
+
class Kinematics(DynSys):
|
|
298
|
+
"""One-dimensional kinematics with constant acceleration.
|
|
299
|
+
|
|
300
|
+
dx/dt = v
|
|
301
|
+
dv/dt = a
|
|
302
|
+
"""
|
|
303
|
+
|
|
304
|
+
def __init__(self):
|
|
305
|
+
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH) # type: ignore # dysts
|
|
306
|
+
|
|
307
|
+
@staticmethod
|
|
308
|
+
def _rhs(x, v, t: float, a) -> np.ndarray: # type: ignore # dysts
|
|
309
|
+
dxdt = v
|
|
310
|
+
dvdt = a
|
|
311
|
+
return np.array([dxdt, dvdt])
|
|
@@ -6,7 +6,7 @@ import jax.numpy as jnp
|
|
|
6
6
|
import sympy2jax
|
|
7
7
|
from sympy import Expr, Symbol
|
|
8
8
|
|
|
9
|
-
from ._typing import
|
|
9
|
+
from ._typing import SimProbData
|
|
10
10
|
|
|
11
11
|
jax.config.update("jax_enable_x64", True)
|
|
12
12
|
|
|
@@ -22,7 +22,7 @@ def _gen_data_jax(
|
|
|
22
22
|
nonnegative: bool,
|
|
23
23
|
dt: float,
|
|
24
24
|
t_end: float,
|
|
25
|
-
) ->
|
|
25
|
+
) -> SimProbData:
|
|
26
26
|
rhstree = sympy2jax.SymbolicModule(exprs)
|
|
27
27
|
|
|
28
28
|
def ode_sys(t, state, args):
|
|
@@ -71,6 +71,8 @@ def _gen_data_jax(
|
|
|
71
71
|
if noise_abs is None:
|
|
72
72
|
assert noise_rel is not None # force type narrowing
|
|
73
73
|
noise_abs = float(jnp.sqrt(_signal_avg_power(x_train_true)) * noise_rel)
|
|
74
|
+
else:
|
|
75
|
+
noise_rel = noise_abs / float(jnp.sqrt(_signal_avg_power(x_train_true)))
|
|
74
76
|
|
|
75
77
|
x_train = x_train_true + jax.random.normal(key, x_train_true.shape) * noise_abs
|
|
76
78
|
|
|
@@ -78,27 +80,16 @@ def _gen_data_jax(
|
|
|
78
80
|
x_train_true_dot = jnp.array([ode_sys(0, xi, None) for xi in x_train_true])
|
|
79
81
|
|
|
80
82
|
stringy_features = [sym.name for sym in input_features]
|
|
81
|
-
return
|
|
82
|
-
|
|
83
|
+
return SimProbData(
|
|
84
|
+
t_train, # type: ignore # jax->numpy
|
|
85
|
+
x_train, # type: ignore # jax->numpy
|
|
86
|
+
stringy_features,
|
|
87
|
+
x_train_true, # type: ignore # jax->numpy
|
|
88
|
+
x_train_true_dot, # type: ignore # jax->numpy
|
|
89
|
+
noise_abs,
|
|
90
|
+
sol,
|
|
83
91
|
)
|
|
84
92
|
|
|
85
93
|
|
|
86
94
|
def _signal_avg_power(signal: jax.Array) -> jax.Array:
|
|
87
95
|
return jnp.square(signal).mean()
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
## % # noqa:E266
|
|
91
|
-
if __name__ == "__main__":
|
|
92
|
-
# Debug example
|
|
93
|
-
from sindy_exp._data import gen_data
|
|
94
|
-
|
|
95
|
-
data_dict = gen_data(
|
|
96
|
-
"valliselnino",
|
|
97
|
-
seed=50,
|
|
98
|
-
n_trajectories=1,
|
|
99
|
-
ic_stdev=3,
|
|
100
|
-
noise_rel=0.1,
|
|
101
|
-
display=True,
|
|
102
|
-
array_namespace="jax",
|
|
103
|
-
)
|
|
104
|
-
print(data_dict["input_features"])
|
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
from importlib import resources
|
|
2
1
|
from logging import getLogger
|
|
3
|
-
from typing import
|
|
2
|
+
from typing import Any, Literal, TypeVar, cast, overload
|
|
4
3
|
|
|
5
4
|
import matplotlib.pyplot as plt
|
|
6
5
|
import numpy as np
|
|
7
|
-
import pysindy as ps
|
|
8
6
|
import sympy as sp
|
|
9
|
-
from dysts.base import DynSys
|
|
10
7
|
|
|
11
8
|
from ._plotting import (
|
|
12
9
|
compare_coefficient_plots_from_dicts,
|
|
@@ -15,8 +12,9 @@ from ._plotting import (
|
|
|
15
12
|
)
|
|
16
13
|
from ._typing import (
|
|
17
14
|
DynamicsTrialData,
|
|
15
|
+
ExperimentResult,
|
|
18
16
|
FullDynamicsTrialData,
|
|
19
|
-
|
|
17
|
+
SimProbData,
|
|
20
18
|
SINDyTrialUpdate,
|
|
21
19
|
_BaseSINDy,
|
|
22
20
|
)
|
|
@@ -41,154 +39,46 @@ metric_ordering = {
|
|
|
41
39
|
T = TypeVar("T", bound=int)
|
|
42
40
|
DType = TypeVar("DType", bound=np.dtype)
|
|
43
41
|
MOD_LOG = getLogger(__name__)
|
|
44
|
-
LOCAL_DYNAMICS_PATH = resources.files("sindy_exp").joinpath("addl_attractors.json")
|
|
45
42
|
|
|
46
43
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
],
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
Args:
|
|
56
|
-
forcing_func: The forcing function to add
|
|
57
|
-
auto_func: An existing rhs func for solve_ivp
|
|
58
|
-
|
|
59
|
-
Returns:
|
|
60
|
-
A rhs function for integration
|
|
61
|
-
"""
|
|
62
|
-
|
|
63
|
-
def sum_of_terms(
|
|
64
|
-
t: float, state: np.ndarray[tuple[T], DType]
|
|
65
|
-
) -> np.ndarray[tuple[T], DType]:
|
|
66
|
-
return np.array(forcing_func(t)) + np.array(auto_func(t, state))
|
|
67
|
-
|
|
68
|
-
return sum_of_terms
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
class LotkaVolterra(DynSys):
|
|
72
|
-
"""Lotka-Volterra (predator-prey) dynamical system."""
|
|
73
|
-
|
|
74
|
-
nonnegative = True
|
|
75
|
-
|
|
76
|
-
def __init__(self):
|
|
77
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
78
|
-
|
|
79
|
-
@staticmethod
|
|
80
|
-
def _rhs(x, y, t: float, alpha, beta, gamma, delta) -> np.ndarray:
|
|
81
|
-
"""LV dynamics
|
|
82
|
-
|
|
83
|
-
Args:
|
|
84
|
-
x: prey population
|
|
85
|
-
y: predator population
|
|
86
|
-
t: time (ignored, since autonomous)
|
|
87
|
-
alpha: prey growth rate
|
|
88
|
-
beta: predation rate
|
|
89
|
-
delta: predator reproduction rate
|
|
90
|
-
gamma: predator death rate
|
|
91
|
-
"""
|
|
92
|
-
dxdt = alpha * x - beta * x * y
|
|
93
|
-
dydt = delta * x * y - gamma * y
|
|
94
|
-
return np.array([dxdt, dydt])
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
class Hopf(DynSys):
|
|
98
|
-
"""Hopf normal form dynamical system."""
|
|
99
|
-
|
|
100
|
-
def __init__(self):
|
|
101
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
102
|
-
|
|
103
|
-
@staticmethod
|
|
104
|
-
def _rhs(x, y, t: float, mu, omega, A) -> np.ndarray:
|
|
105
|
-
dxdt = mu * x - omega * y - A * (x**3 + x * y**2)
|
|
106
|
-
dydt = omega * x + mu * y - A * (x**2 * y + y**3)
|
|
107
|
-
return np.array([dxdt, dydt])
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
class SHO(DynSys):
|
|
111
|
-
"""Linear damped simple harmonic oscillator"""
|
|
112
|
-
|
|
113
|
-
def __init__(self):
|
|
114
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
115
|
-
|
|
116
|
-
@staticmethod
|
|
117
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
118
|
-
dxdt = a * x + b * y
|
|
119
|
-
dydt = c * x + d * y
|
|
120
|
-
return np.array([dxdt, dydt])
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
class CubicHO(DynSys):
|
|
124
|
-
"""Cubic damped harmonic oscillator."""
|
|
125
|
-
|
|
126
|
-
def __init__(self):
|
|
127
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
128
|
-
|
|
129
|
-
@staticmethod
|
|
130
|
-
def _rhs(x, y, t: float, a, b, c, d) -> np.ndarray:
|
|
131
|
-
dxdt = a * x**3 + b * y**3
|
|
132
|
-
dydt = c * x**3 + d * y**3
|
|
133
|
-
return np.array([dxdt, dydt])
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class VanDerPol(DynSys):
|
|
137
|
-
"""Van der Pol oscillator.
|
|
138
|
-
|
|
139
|
-
dx/dt = y
|
|
140
|
-
dy/dt = mu * (1 - x^2) * y - x
|
|
141
|
-
"""
|
|
142
|
-
|
|
143
|
-
def __init__(self):
|
|
144
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
145
|
-
|
|
146
|
-
@staticmethod
|
|
147
|
-
def _rhs(x, x_dot, t: float, mu) -> np.ndarray:
|
|
148
|
-
dxdt = x_dot
|
|
149
|
-
dx2dt2 = mu * (1 - x**2) * x_dot - x
|
|
150
|
-
return np.array([dxdt, dx2dt2])
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class Kinematics(DynSys):
|
|
154
|
-
"""One-dimensional kinematics with constant acceleration.
|
|
155
|
-
|
|
156
|
-
dx/dt = v
|
|
157
|
-
dv/dt = a
|
|
158
|
-
"""
|
|
44
|
+
@overload
|
|
45
|
+
def fit_eval(
|
|
46
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
47
|
+
model: _BaseSINDy,
|
|
48
|
+
simulations: Literal[False],
|
|
49
|
+
display: bool,
|
|
50
|
+
) -> ExperimentResult[DynamicsTrialData]: ...
|
|
159
51
|
|
|
160
|
-
def __init__(self):
|
|
161
|
-
super().__init__(metadata_path=LOCAL_DYNAMICS_PATH)
|
|
162
52
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
53
|
+
@overload
|
|
54
|
+
def fit_eval(
|
|
55
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
56
|
+
model: _BaseSINDy,
|
|
57
|
+
simulations: Literal[True],
|
|
58
|
+
display: bool,
|
|
59
|
+
) -> ExperimentResult[FullDynamicsTrialData]: ...
|
|
168
60
|
|
|
169
61
|
|
|
170
62
|
def fit_eval(
|
|
171
|
-
data: tuple[list[
|
|
172
|
-
model:
|
|
63
|
+
data: tuple[list[SimProbData], list[dict[sp.Expr, float]]],
|
|
64
|
+
model: Any,
|
|
173
65
|
simulations: bool = True,
|
|
174
66
|
display: bool = True,
|
|
175
|
-
|
|
176
|
-
) -> dict | tuple[dict, DynamicsTrialData | FullDynamicsTrialData]:
|
|
67
|
+
) -> ExperimentResult:
|
|
177
68
|
"""Fit and evaluate a SINDy model on a set of trajectories.
|
|
178
69
|
|
|
179
70
|
Args:
|
|
180
71
|
data: Tuple of (trajectories, true_equations), where ``trajectories`` is
|
|
181
|
-
a list of
|
|
72
|
+
a list of SimProbData objects and ``true_equations`` is a list of
|
|
182
73
|
dictionaries mapping SymPy symbols to their true coefficients for
|
|
183
74
|
each state coordinate.
|
|
184
75
|
model: A SINDy-like model implementing the _BaseSINDy protocol.
|
|
185
76
|
simulations: Whether to run forward simulations for evaluation.
|
|
186
77
|
display: Whether to generate plots as part of evaluation.
|
|
187
|
-
return_all: If True, return a dictionary containing metrics and the
|
|
188
|
-
assembled DynamicsTrialData; otherwise return only the metrics
|
|
189
|
-
dictionary.
|
|
190
78
|
"""
|
|
191
|
-
|
|
79
|
+
model = cast(_BaseSINDy, model)
|
|
80
|
+
for trajectory in data[0]:
|
|
81
|
+
assert trajectory.x_train_true is not None
|
|
192
82
|
trajectories, true_equations = data
|
|
193
83
|
input_features = trajectories[0].input_features
|
|
194
84
|
|
|
@@ -198,12 +88,16 @@ def fit_eval(
|
|
|
198
88
|
|
|
199
89
|
MOD_LOG.info(f"Fitting a model: {model}")
|
|
200
90
|
coeff_true_dicts, coeff_est_dicts = unionize_coeff_dicts(model, true_equations)
|
|
201
|
-
|
|
91
|
+
|
|
92
|
+
# Special workaround for pysindy's legacy WeakPDELibrary
|
|
93
|
+
if hasattr(model.feature_library, "K"):
|
|
202
94
|
# WeakPDE library fails to simulate, so insert nonweak library
|
|
203
95
|
# to Pipeline and SINDy model.
|
|
204
96
|
inner_lib = model.feature_library.function_library
|
|
205
97
|
model.feature_library = inner_lib # type: ignore # TODO: Fix in pysindy
|
|
206
|
-
|
|
98
|
+
|
|
99
|
+
# Special workaround for pysindy's bad (soon to be legacy) differentiation API
|
|
100
|
+
if hasattr(model, "differentiation_method") and hasattr(
|
|
207
101
|
model.differentiation_method, "smoothed_x_"
|
|
208
102
|
):
|
|
209
103
|
smooth_x = []
|
|
@@ -212,6 +106,7 @@ def fit_eval(
|
|
|
212
106
|
smooth_x.append(model.differentiation_method.smoothed_x_)
|
|
213
107
|
else: # using WeakPDELibrary
|
|
214
108
|
smooth_x = x_train
|
|
109
|
+
|
|
215
110
|
trial_data = DynamicsTrialData(
|
|
216
111
|
trajectories=trajectories,
|
|
217
112
|
true_equations=coeff_true_dicts,
|
|
@@ -226,7 +121,7 @@ def fit_eval(
|
|
|
226
121
|
sims: list[SINDyTrialUpdate] = []
|
|
227
122
|
integration_metric_list: list[dict[str, float | np.floating]] = []
|
|
228
123
|
for traj in trajectories:
|
|
229
|
-
sim = _simulate_test_data(model, traj.
|
|
124
|
+
sim = _simulate_test_data(model, traj.t_train, traj.x_train_true)
|
|
230
125
|
sims.append(sim)
|
|
231
126
|
integration_metric_list.append(
|
|
232
127
|
integration_metrics(
|
|
@@ -237,9 +132,9 @@ def fit_eval(
|
|
|
237
132
|
)
|
|
238
133
|
)
|
|
239
134
|
|
|
240
|
-
agg_integration_metrics: dict[str, float
|
|
135
|
+
agg_integration_metrics: dict[str, float] = {}
|
|
241
136
|
for key in integration_metric_list[0].keys():
|
|
242
|
-
values = [m[key] for m in integration_metric_list]
|
|
137
|
+
values = cast(list[float], [m[key] for m in integration_metric_list])
|
|
243
138
|
agg_integration_metrics[key] = float(np.mean(values))
|
|
244
139
|
metrics.update(agg_integration_metrics)
|
|
245
140
|
|
|
@@ -264,9 +159,8 @@ def fit_eval(
|
|
|
264
159
|
figs=(fig_composite, fig_by_coord_1d),
|
|
265
160
|
coord_names=input_features,
|
|
266
161
|
)
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
return metrics
|
|
162
|
+
|
|
163
|
+
return {"metrics": metrics, "data": trial_data, "main": metrics["main"]}
|
|
270
164
|
|
|
271
165
|
|
|
272
166
|
def plot_ode_panel(trial_data: DynamicsTrialData):
|
|
@@ -57,7 +57,7 @@ def plot_coefficients(
|
|
|
57
57
|
feature_names: Sequence[str],
|
|
58
58
|
ax: Axes,
|
|
59
59
|
**heatmap_kws,
|
|
60
|
-
) ->
|
|
60
|
+
) -> Axes:
|
|
61
61
|
"""Plot a set of dynamical system coefficients in a heatmap.
|
|
62
62
|
|
|
63
63
|
Args:
|
|
@@ -162,6 +162,7 @@ def _compare_coefficient_plots_impl(
|
|
|
162
162
|
1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True
|
|
163
163
|
)
|
|
164
164
|
fig.tight_layout()
|
|
165
|
+
assert axs is not None # type narrowing
|
|
165
166
|
|
|
166
167
|
vmax = signed_root(max_val)
|
|
167
168
|
|
|
@@ -275,7 +276,12 @@ def _plot_training_trajectory(
|
|
|
275
276
|
"""
|
|
276
277
|
if x_train.shape[1] == 2:
|
|
277
278
|
ax.plot(
|
|
278
|
-
x_true[:, 0],
|
|
279
|
+
x_true[:, 0],
|
|
280
|
+
x_true[:, 1],
|
|
281
|
+
".",
|
|
282
|
+
label="True",
|
|
283
|
+
color=COLOR.TRUE,
|
|
284
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
279
285
|
)
|
|
280
286
|
ax.plot(
|
|
281
287
|
x_train[:, 0],
|
|
@@ -283,7 +289,7 @@ def _plot_training_trajectory(
|
|
|
283
289
|
".",
|
|
284
290
|
label="Measured",
|
|
285
291
|
color=COLOR.MEAS,
|
|
286
|
-
**PLOT_KWS,
|
|
292
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
287
293
|
)
|
|
288
294
|
if (
|
|
289
295
|
x_smooth is not None
|
|
@@ -295,7 +301,7 @@ def _plot_training_trajectory(
|
|
|
295
301
|
".",
|
|
296
302
|
label="Smoothed",
|
|
297
303
|
color=COLOR.EST,
|
|
298
|
-
**PLOT_KWS,
|
|
304
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
299
305
|
)
|
|
300
306
|
if labels:
|
|
301
307
|
ax.set(xlabel="$x_0$", ylabel="$x_1$")
|
|
@@ -308,7 +314,7 @@ def _plot_training_trajectory(
|
|
|
308
314
|
x_true[:, 2],
|
|
309
315
|
color=COLOR.TRUE,
|
|
310
316
|
label="True values",
|
|
311
|
-
**PLOT_KWS,
|
|
317
|
+
**PLOT_KWS, # type: ignore[arg-type]
|
|
312
318
|
)
|
|
313
319
|
|
|
314
320
|
ax.plot(
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
|
+
from collections.abc import Mapping
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from typing import (
|
|
4
5
|
Any,
|
|
5
6
|
Callable,
|
|
6
7
|
Literal,
|
|
7
|
-
NamedTuple,
|
|
8
8
|
Optional,
|
|
9
9
|
Protocol,
|
|
10
|
+
TypedDict,
|
|
10
11
|
TypeVar,
|
|
11
12
|
overload,
|
|
12
13
|
)
|
|
@@ -27,6 +28,14 @@ FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]
|
|
|
27
28
|
TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class ExperimentResult[T](TypedDict):
|
|
32
|
+
"""Results from a SINDy ODE experiment."""
|
|
33
|
+
|
|
34
|
+
metrics: Mapping[str, float | None]
|
|
35
|
+
data: T
|
|
36
|
+
main: object
|
|
37
|
+
|
|
38
|
+
|
|
30
39
|
class _BaseSINDy(Protocol):
|
|
31
40
|
optimizer: Any
|
|
32
41
|
feature_library: Any
|
|
@@ -62,23 +71,38 @@ class _BaseSINDy(Protocol):
|
|
|
62
71
|
self, precision: int, fmt: Literal["sympy"]
|
|
63
72
|
) -> list[dict[Expr, float]]: ...
|
|
64
73
|
|
|
74
|
+
@overload
|
|
75
|
+
def print(self, **kwargs) -> None: ...
|
|
76
|
+
|
|
77
|
+
@overload
|
|
65
78
|
def print(self, precision: int, **kwargs) -> None: ...
|
|
66
79
|
|
|
67
80
|
def get_feature_names(self) -> list[str]: ...
|
|
68
81
|
|
|
69
82
|
|
|
70
|
-
|
|
71
|
-
|
|
83
|
+
@dataclass
|
|
84
|
+
class ProbData:
|
|
85
|
+
"""Represents a single trajectory's data.
|
|
72
86
|
|
|
73
|
-
|
|
87
|
+
For measured data, only t_train, x_train, and input_features are required.
|
|
74
88
|
"""
|
|
75
89
|
|
|
76
|
-
dt: float
|
|
77
90
|
t_train: Float1D
|
|
78
91
|
x_train: Float2D
|
|
92
|
+
input_features: list[str]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
|
|
96
|
+
class SimProbData(ProbData):
|
|
97
|
+
"""For simulated data, the noiseless trajectory is known.
|
|
98
|
+
|
|
99
|
+
Optionally includes the integrator solution object for evaluating
|
|
100
|
+
at other points.
|
|
101
|
+
"""
|
|
102
|
+
|
|
79
103
|
x_train_true: Float2D
|
|
80
104
|
x_train_true_dot: Float2D
|
|
81
|
-
|
|
105
|
+
noise_abs: float
|
|
82
106
|
integrator: Optional[Any] = None # diffrax.Solution
|
|
83
107
|
|
|
84
108
|
|
|
@@ -138,7 +162,7 @@ class NestedDict(defaultdict):
|
|
|
138
162
|
|
|
139
163
|
@dataclass
|
|
140
164
|
class DynamicsTrialData:
|
|
141
|
-
trajectories: list[
|
|
165
|
+
trajectories: list[SimProbData]
|
|
142
166
|
true_equations: list[dict[sp.Expr, float]]
|
|
143
167
|
sindy_equations: list[dict[sp.Expr, float]]
|
|
144
168
|
model: _BaseSINDy
|
|
@@ -149,7 +149,7 @@ def opt_lookup(kind):
|
|
|
149
149
|
def coeff_metrics(
|
|
150
150
|
coeff_est_dicts: list[dict[sp.Expr, float]],
|
|
151
151
|
coeff_true_dicts: list[dict[sp.Expr, float]],
|
|
152
|
-
) -> dict[str, float
|
|
152
|
+
) -> dict[str, float]:
|
|
153
153
|
"""Compute coefficient metrics from aligned coefficient dictionaries.
|
|
154
154
|
|
|
155
155
|
Both arguments are expected to be lists of coefficient dictionaries sharing
|
|
@@ -182,14 +182,18 @@ def coeff_metrics(
|
|
|
182
182
|
coefficients[row_ind, col_ind] = est_row[feat]
|
|
183
183
|
|
|
184
184
|
metrics: dict[str, float | np.floating] = {}
|
|
185
|
-
metrics["coeff_precision"] =
|
|
186
|
-
|
|
185
|
+
metrics["coeff_precision"] = float(
|
|
186
|
+
sklearn.metrics.precision_score(
|
|
187
|
+
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
188
|
+
)
|
|
187
189
|
)
|
|
188
|
-
metrics["coeff_recall"] =
|
|
189
|
-
|
|
190
|
+
metrics["coeff_recall"] = float(
|
|
191
|
+
sklearn.metrics.recall_score(
|
|
192
|
+
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
193
|
+
)
|
|
190
194
|
)
|
|
191
|
-
metrics["coeff_f1"] =
|
|
192
|
-
coeff_true.flatten() != 0, coefficients.flatten() != 0
|
|
195
|
+
metrics["coeff_f1"] = float(
|
|
196
|
+
sklearn.metrics.f1_score(coeff_true.flatten() != 0, coefficients.flatten() != 0)
|
|
193
197
|
)
|
|
194
198
|
metrics["coeff_mse"] = sklearn.metrics.mean_squared_error(
|
|
195
199
|
coeff_true.flatten(), coefficients.flatten()
|
|
@@ -198,7 +202,7 @@ def coeff_metrics(
|
|
|
198
202
|
coeff_true.flatten(), coefficients.flatten()
|
|
199
203
|
)
|
|
200
204
|
metrics["main"] = metrics["coeff_f1"]
|
|
201
|
-
return metrics
|
|
205
|
+
return {k: float(v) for k, v in metrics.items()}
|
|
202
206
|
|
|
203
207
|
|
|
204
208
|
def pred_metrics(
|
|
@@ -279,53 +283,8 @@ def unionize_coeff_dicts(
|
|
|
279
283
|
return true_aligned, est_aligned
|
|
280
284
|
|
|
281
285
|
|
|
282
|
-
def make_model(
|
|
283
|
-
input_features: list[str],
|
|
284
|
-
dt: float,
|
|
285
|
-
diff_params: dict | ps.BaseDifferentiation,
|
|
286
|
-
feat_params: dict | ps.feature_library.base.BaseFeatureLibrary,
|
|
287
|
-
opt_params: dict | ps.BaseOptimizer,
|
|
288
|
-
) -> ps.SINDy:
|
|
289
|
-
"""Build a model with object parameters dictionaries
|
|
290
|
-
|
|
291
|
-
e.g. {"kind": "finitedifference"} instead of FiniteDifference()
|
|
292
|
-
"""
|
|
293
|
-
|
|
294
|
-
def finalize_param(lookup_func, pdict, lookup_key):
|
|
295
|
-
try:
|
|
296
|
-
cls_name = pdict.pop(lookup_key)
|
|
297
|
-
except AttributeError:
|
|
298
|
-
cls_name = pdict.vals.pop(lookup_key)
|
|
299
|
-
pdict = pdict.vals
|
|
300
|
-
|
|
301
|
-
param_cls = lookup_func(cls_name)
|
|
302
|
-
param_final = param_cls(**pdict)
|
|
303
|
-
pdict[lookup_key] = cls_name
|
|
304
|
-
return param_final
|
|
305
|
-
|
|
306
|
-
if isinstance(diff_params, ps.BaseDifferentiation):
|
|
307
|
-
diff = diff_params
|
|
308
|
-
else:
|
|
309
|
-
diff = finalize_param(diff_lookup, diff_params, "diffcls")
|
|
310
|
-
if isinstance(feat_params, ps.feature_library.base.BaseFeatureLibrary):
|
|
311
|
-
features = feat_params
|
|
312
|
-
else:
|
|
313
|
-
features = finalize_param(feature_lookup, feat_params, "featcls")
|
|
314
|
-
if isinstance(opt_params, ps.BaseOptimizer):
|
|
315
|
-
opt = opt_params
|
|
316
|
-
else:
|
|
317
|
-
opt = finalize_param(opt_lookup, opt_params, "optcls")
|
|
318
|
-
return ps.SINDy(
|
|
319
|
-
differentiation_method=diff,
|
|
320
|
-
optimizer=opt,
|
|
321
|
-
t_default=dt, # type: ignore
|
|
322
|
-
feature_library=features,
|
|
323
|
-
feature_names=input_features,
|
|
324
|
-
)
|
|
325
|
-
|
|
326
|
-
|
|
327
286
|
def _simulate_test_data(
|
|
328
|
-
model: _BaseSINDy,
|
|
287
|
+
model: _BaseSINDy, t_test: Float1D, x_test: Float2D
|
|
329
288
|
) -> SINDyTrialUpdate:
|
|
330
289
|
"""Add simulation data to grid_data
|
|
331
290
|
|
|
@@ -333,7 +292,6 @@ def _simulate_test_data(
|
|
|
333
292
|
Returns:
|
|
334
293
|
Complete GridPointData
|
|
335
294
|
"""
|
|
336
|
-
t_test = cast(Float1D, np.arange(0, len(x_test) * dt, step=dt))
|
|
337
295
|
t_sim = t_test
|
|
338
296
|
try:
|
|
339
297
|
|
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sindy-exp
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A basic library for constructing dynamics experiments
|
|
5
5
|
Author-email: Jake Stevens-Haas <jacob.stevens.haas@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -34,7 +34,7 @@ Classifier: Intended Audience :: Science/Research
|
|
|
34
34
|
Classifier: License :: OSI Approved :: MIT License
|
|
35
35
|
Classifier: Natural Language :: English
|
|
36
36
|
Classifier: Operating System :: POSIX :: Linux
|
|
37
|
-
Requires-Python: >=3.
|
|
37
|
+
Requires-Python: >=3.12
|
|
38
38
|
Description-Content-Type: text/markdown
|
|
39
39
|
License-File: LICENSE
|
|
40
40
|
Requires-Dist: matplotlib
|
|
@@ -43,6 +43,7 @@ Requires-Dist: seaborn
|
|
|
43
43
|
Requires-Dist: scipy
|
|
44
44
|
Requires-Dist: sympy
|
|
45
45
|
Requires-Dist: dysts
|
|
46
|
+
Requires-Dist: scikit-learn
|
|
46
47
|
Provides-Extra: jax
|
|
47
48
|
Requires-Dist: jax[cuda12]; extra == "jax"
|
|
48
49
|
Requires-Dist: diffrax; extra == "jax"
|
|
@@ -64,14 +65,20 @@ Requires-Dist: tomli; extra == "dev"
|
|
|
64
65
|
Requires-Dist: pysindy>=2.1.0; extra == "dev"
|
|
65
66
|
Dynamic: license-file
|
|
66
67
|
|
|
67
|
-
#
|
|
68
|
+
# Overview
|
|
68
69
|
|
|
69
|
-
A library for constructing dynamics experiments.
|
|
70
|
-
This includes data generation and
|
|
70
|
+
A library for constructing dynamics experiments from the dynamics models in the `dysts` package.
|
|
71
|
+
This includes data generation and model evaluation.
|
|
72
|
+
The first contribution is the static typing of trajectory data (`ProbData`) that, I believe, provides the necessary information to be useful in evaluating a wide variety of dynamics/time-series learning methods.
|
|
73
|
+
The second contribution is the collection of utility functions for designing dynamics learning experiments.
|
|
74
|
+
The third contribution is the collection of such experiments for evaluating dynamics/time-series learning models that meet the `BaseSINDy` API.
|
|
75
|
+
|
|
76
|
+
It aims to (a) be amenable to both `numpy` and `jax` arrays, and (b) be usable by any dynamics/time-series learning models that meet the `BaseSINDy` or scikit-time API.
|
|
77
|
+
Internally, this package is used (or will be used) for benchmarking pysindy runtime/memory usage and for choosing default hyperparameters.
|
|
71
78
|
|
|
72
79
|
## Getting started
|
|
73
80
|
|
|
74
|
-
|
|
81
|
+
Install with `pip install sindy-exp` or `pip install sindy-exp[jax]`.
|
|
75
82
|
|
|
76
83
|
Generate data
|
|
77
84
|
|
|
@@ -86,26 +93,26 @@ Evaluate your SINDy-like model with:
|
|
|
86
93
|
A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
|
|
87
94
|
of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
|
|
88
95
|
|
|
89
|
-
## ODE
|
|
96
|
+
## ODE & Data Model
|
|
97
|
+
|
|
98
|
+
Measured or generated data has the dataclass type `ProbData` or `SimProbData`, respectively,
|
|
99
|
+
to indicate whether it includes ground truth information and a noise level.
|
|
100
|
+
If the data is generated in jax, it will have an integrator that can later be used to evaluate the true data on collocation points.
|
|
90
101
|
|
|
91
102
|
We deal primarily with autonomous ODE systems of the form:
|
|
92
103
|
|
|
93
104
|
dx/dt = sum_i f_i(x)
|
|
94
105
|
|
|
95
|
-
|
|
106
|
+
We represent ODE systems as a list of right-hand side expressions.
|
|
96
107
|
Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
|
|
108
|
+
Thus, the rhs of an ODE is of type: `list[dict[sympy.Expr, float]]`
|
|
97
109
|
|
|
98
110
|
## Other useful imports, compatibility, and extensions
|
|
99
111
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
To integrate your own experiments or data generation in a way that is compatible,
|
|
105
|
-
see the `ProbData` and `DynamicsTrialData` classes.
|
|
106
|
-
For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`,
|
|
107
|
-
`plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
108
|
-
For metrics, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
112
|
+
* The experiments are built to be compatible with the `mitosis` tool, an experiment runner. Mitosis is not a dependency, however, to allow using other experiment runners.
|
|
113
|
+
* To integrate your own experiments or data generation in a way that is compatible, see the `ProbData`, `SimProbData`, `DynamicsTrialData`, and `FullDynamicsTrialData` classes.
|
|
114
|
+
* For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`, `plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
115
|
+
* For evaluation of models, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
109
116
|
|
|
110
117
|

|
|
111
118
|

|
|
@@ -7,6 +7,9 @@ pyproject.toml
|
|
|
7
7
|
setup.cfg
|
|
8
8
|
.github/workflows/main.yaml
|
|
9
9
|
.github/workflows/release.yml
|
|
10
|
+
images/1d.png
|
|
11
|
+
images/coeff.png
|
|
12
|
+
images/composite.png
|
|
10
13
|
src/sindy_exp/__init__.py
|
|
11
14
|
src/sindy_exp/_data.py
|
|
12
15
|
src/sindy_exp/_diffrax_solver.py
|
|
@@ -16,6 +19,7 @@ src/sindy_exp/_plotting.py
|
|
|
16
19
|
src/sindy_exp/_typing.py
|
|
17
20
|
src/sindy_exp/_utils.py
|
|
18
21
|
src/sindy_exp/addl_attractors.json
|
|
22
|
+
src/sindy_exp/py.typed
|
|
19
23
|
src/sindy_exp.egg-info/PKG-INFO
|
|
20
24
|
src/sindy_exp.egg-info/SOURCES.txt
|
|
21
25
|
src/sindy_exp.egg-info/dependency_links.txt
|
|
@@ -180,7 +180,7 @@ def test_gen_data(rhs_name, array_namespace, jax_cpu_only):
|
|
|
180
180
|
result = gen_data(
|
|
181
181
|
rhs_name, t_end=0.1, noise_abs=0.01, seed=42, array_namespace=array_namespace
|
|
182
182
|
)["data"]
|
|
183
|
-
trajectories = result[
|
|
183
|
+
trajectories = result[0]
|
|
184
184
|
assert len(trajectories) == 1
|
|
185
185
|
traj = trajectories[0]
|
|
186
186
|
assert traj.x_train.shape == traj.x_train_true_dot.shape
|
sindy_exp-0.2.1/README.md
DELETED
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
# Dynamics Experiments
|
|
2
|
-
|
|
3
|
-
A library for constructing dynamics experiments.
|
|
4
|
-
This includes data generation and plotting/evaluation.
|
|
5
|
-
|
|
6
|
-
## Getting started
|
|
7
|
-
|
|
8
|
-
It's not yet on PyPI, so install it with `pip install sindy_exp @ git+https://github.com/Jacob-Stevens-Haas/sindy-experiments`
|
|
9
|
-
|
|
10
|
-
Generate data
|
|
11
|
-
|
|
12
|
-
data = sindy_exp.data.gen_data("lorenz", num_trajectories=5, t_end=10.0, dt=0.01)["data]
|
|
13
|
-
|
|
14
|
-
Evaluate your SINDy-like model with:
|
|
15
|
-
|
|
16
|
-
sindy_exp.odes.fit_eval(model, data)
|
|
17
|
-
|
|
18
|
-

|
|
19
|
-
|
|
20
|
-
A list of available ODE systems can be found in `ODE_CLASSES`, which includes most
|
|
21
|
-
of the systems from the [dysts package](https://pypi.org/project/dysts/) as well as some non-chaotic systems.
|
|
22
|
-
|
|
23
|
-
## ODE representation
|
|
24
|
-
|
|
25
|
-
We deal primarily with autonomous ODE systems of the form:
|
|
26
|
-
|
|
27
|
-
dx/dt = sum_i f_i(x)
|
|
28
|
-
|
|
29
|
-
Thus, we represent ODE systems as a list of right-hand side expressions.
|
|
30
|
-
Each element is a dictionary mapping a term (Sympy expression) to its coefficient.
|
|
31
|
-
|
|
32
|
-
## Other useful imports, compatibility, and extensions
|
|
33
|
-
|
|
34
|
-
This is built to be compatible with dynamics learning models that follow the
|
|
35
|
-
pysindy _BaseSINDy interface.
|
|
36
|
-
The experiments are also built to be compatible with the `mitosis` tool,
|
|
37
|
-
an experiment runner.
|
|
38
|
-
To integrate your own experiments or data generation in a way that is compatible,
|
|
39
|
-
see the `ProbData` and `DynamicsTrialData` classes.
|
|
40
|
-
For plotting tools, see `plot_coefficients`, `compare_coefficient_plots_from_dicts`,
|
|
41
|
-
`plot_test_trajectory`, `plot_training_data`, and `COLOR`.
|
|
42
|
-
For metrics, see `coeff_metrics`, `pred_metrics`, and `integration_metrics`.
|
|
43
|
-
|
|
44
|
-

|
|
45
|
-

|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|