cd-dynamax 0.3.2__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/PKG-INFO +46 -27
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/README.md +44 -26
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +7 -8
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +14 -20
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +4 -4
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +6 -6
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +15 -15
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +4 -4
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/PKG-INFO +46 -27
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/SOURCES.txt +1 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/requires.txt +3 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/pyproject.toml +3 -2
- cd_dynamax-0.3.3/tests/test_state_dependent_diffusion.py +93 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/LICENSE +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/_version.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/initial.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/parameters.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/parameters_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/ssm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/types.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/bijectors.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/distributions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/distributions_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/optimize.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/plotting.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/utils_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/warnings.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/ssm_temissions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/data_driven_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/data_generator.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/debug_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/demo_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/diffrax_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/evaluation_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/experiment_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/likelihood_eval_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/optimize_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/physics_based_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/plotting_chaos_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/plotting_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/prior_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/simulation_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/test_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/dependency_links.txt +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/top_level.txt +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/setup.cfg +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_cdlgssm_dlgssm_match.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_cdnonlinear_cdlinear_match.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_filter_forecast_emissions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_imports.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_utils_imports.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cd_dynamax
|
|
3
|
-
Version: 0.3.2
|
|
3
|
+
Version: 0.3.3
|
|
4
4
|
Summary: Continuous-discrete dynamical systems with JAX and related libraries.
|
|
5
5
|
Author: Matthew Levine, Iñigo Urteaga
|
|
6
6
|
Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
|
|
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
|
|
|
32
32
|
Requires-Dist: fastprogress>=1.0.0
|
|
33
33
|
Requires-Dist: graphviz
|
|
34
34
|
Requires-Dist: ipykernel
|
|
35
|
+
Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
|
|
35
36
|
Provides-Extra: dev
|
|
36
37
|
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
37
38
|
Requires-Dist: ruff; extra == "dev"
|
|
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
|
|
|
69
70
|
|
|
70
71
|
where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
|
|
71
72
|
|
|
72
|
-
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\
|
|
73
|
+
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
|
|
73
74
|
|
|
74
75
|
Note:
|
|
75
76
|
|
|
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
|
|
|
90
91
|
|
|
91
92
|
All of these problems are deeply interconnected.
|
|
92
93
|
|
|
93
|
-
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]
|
|
94
|
+
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$).
|
|
94
95
|
|
|
95
96
|
- In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
|
|
96
97
|
- In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
|
|
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
|
|
|
113
114
|
|
|
114
115
|
- The codebase is organized into several key directories:
|
|
115
116
|
```
|
|
116
|
-
|
|
117
|
-
├──
|
|
118
|
-
│ ├──
|
|
119
|
-
│ ├──
|
|
120
|
-
│ ├──
|
|
121
|
-
│
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
├──
|
|
125
|
-
├── python/
|
|
126
|
-
├── python/
|
|
127
|
-
|
|
117
|
+
.
|
|
118
|
+
├── cd_dynamax/ # Source code for cd-dynamax library
|
|
119
|
+
│ ├── src/ # Core source code
|
|
120
|
+
│ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
|
|
121
|
+
│ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
|
|
122
|
+
│ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
|
|
123
|
+
│ │ └── utils/ # Utility functions and example models
|
|
124
|
+
│ └── dynamax/ # Original dynamax library (as a submodule)
|
|
125
|
+
├── demos/ # Python demos showcasing cd-dynamax functionality
|
|
126
|
+
│ ├── python/scripts/ # Python scripts for running demos
|
|
127
|
+
│ ├── python/notebooks/ # Jupyter notebooks for interactive demos
|
|
128
|
+
│ └── python/configs/ # Configuration files for demos
|
|
129
|
+
└── tests/ # Tests for cd-dynamax functionality
|
|
128
130
|
```
|
|
129
131
|
|
|
130
132
|
## [Demos](./demos/python)
|
|
@@ -155,7 +157,7 @@ make test
|
|
|
155
157
|
```
|
|
156
158
|
|
|
157
159
|
- For linting, we use `ruff`:
|
|
158
|
-
```
|
|
160
|
+
```bash
|
|
159
161
|
make lint
|
|
160
162
|
```
|
|
161
163
|
|
|
@@ -171,40 +173,57 @@ make build_docs
|
|
|
171
173
|
|
|
172
174
|
# Installation
|
|
173
175
|
|
|
174
|
-
|
|
176
|
+
Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
|
|
175
177
|
|
|
176
178
|
---
|
|
177
179
|
|
|
178
|
-
### Option 1:
|
|
180
|
+
### Option 1: Install from PyPI (recommended)
|
|
179
181
|
|
|
180
182
|
```bash
|
|
181
|
-
# Create and activate a
|
|
182
|
-
|
|
183
|
-
|
|
183
|
+
# Create and activate a virtual environment
|
|
184
|
+
python -m venv .venv # or `uv venv`
|
|
185
|
+
source .venv/bin/activate # on macOS/Linux
|
|
186
|
+
.venv\Scripts\activate # on Windows
|
|
184
187
|
|
|
185
|
-
#
|
|
186
|
-
pip install
|
|
188
|
+
# Upgrade pip
|
|
189
|
+
pip install --upgrade pip
|
|
190
|
+
|
|
191
|
+
# Install latest release from PyPI
|
|
192
|
+
pip install cd-dynamax
|
|
187
193
|
```
|
|
188
194
|
|
|
189
|
-
|
|
195
|
+
`cd-dynamax` is currently **not available on Conda Forge**.
|
|
190
196
|
|
|
191
197
|
---
|
|
192
198
|
|
|
193
|
-
### Option 2:
|
|
199
|
+
### Option 2: Install from source (editable)
|
|
194
200
|
|
|
195
201
|
```bash
|
|
196
202
|
# Create and activate a virtual environment
|
|
197
|
-
python -m venv .venv
|
|
203
|
+
python -m venv .venv # or `uv venv`
|
|
198
204
|
source .venv/bin/activate # on macOS/Linux
|
|
199
205
|
.venv\Scripts\activate # on Windows
|
|
200
206
|
|
|
201
207
|
# Upgrade pip
|
|
202
208
|
pip install --upgrade pip
|
|
203
209
|
|
|
204
|
-
# Install in editable mode
|
|
210
|
+
# Install in editable mode for local development
|
|
205
211
|
pip install -e .[dev]
|
|
206
212
|
```
|
|
207
213
|
|
|
214
|
+
---
|
|
215
|
+
|
|
216
|
+
### Option 3: Conda environment + pip install
|
|
217
|
+
|
|
218
|
+
```bash
|
|
219
|
+
# Create and activate a Conda environment with Python 3.11
|
|
220
|
+
conda create -n cd_dynamax python=3.11
|
|
221
|
+
conda activate cd_dynamax
|
|
222
|
+
|
|
223
|
+
# Install latest release from PyPI
|
|
224
|
+
pip install cd-dynamax
|
|
225
|
+
```
|
|
226
|
+
|
|
208
227
|
#### GPU support
|
|
209
228
|
If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
|
|
210
229
|
|
|
@@ -25,7 +25,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
|
|
|
25
25
|
|
|
26
26
|
where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
|
|
27
27
|
|
|
28
|
-
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\
|
|
28
|
+
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
|
|
29
29
|
|
|
30
30
|
Note:
|
|
31
31
|
|
|
@@ -46,7 +46,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
|
|
|
46
46
|
|
|
47
47
|
All of these problems are deeply interconnected.
|
|
48
48
|
|
|
49
|
-
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]
|
|
49
|
+
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$).
|
|
50
50
|
|
|
51
51
|
- In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
|
|
52
52
|
- In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
|
|
@@ -69,18 +69,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
|
|
|
69
69
|
|
|
70
70
|
- The codebase is organized into several key directories:
|
|
71
71
|
```
|
|
72
|
-
|
|
73
|
-
├──
|
|
74
|
-
│ ├──
|
|
75
|
-
│ ├──
|
|
76
|
-
│ ├──
|
|
77
|
-
│
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
├──
|
|
81
|
-
├── python/
|
|
82
|
-
├── python/
|
|
83
|
-
|
|
72
|
+
.
|
|
73
|
+
├── cd_dynamax/ # Source code for cd-dynamax library
|
|
74
|
+
│ ├── src/ # Core source code
|
|
75
|
+
│ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
|
|
76
|
+
│ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
|
|
77
|
+
│ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
|
|
78
|
+
│ │ └── utils/ # Utility functions and example models
|
|
79
|
+
│ └── dynamax/ # Original dynamax library (as a submodule)
|
|
80
|
+
├── demos/ # Python demos showcasing cd-dynamax functionality
|
|
81
|
+
│ ├── python/scripts/ # Python scripts for running demos
|
|
82
|
+
│ ├── python/notebooks/ # Jupyter notebooks for interactive demos
|
|
83
|
+
│ └── python/configs/ # Configuration files for demos
|
|
84
|
+
└── tests/ # Tests for cd-dynamax functionality
|
|
84
85
|
```
|
|
85
86
|
|
|
86
87
|
## [Demos](./demos/python)
|
|
@@ -111,7 +112,7 @@ make test
|
|
|
111
112
|
```
|
|
112
113
|
|
|
113
114
|
- For linting, we use `ruff`:
|
|
114
|
-
```
|
|
115
|
+
```bash
|
|
115
116
|
make lint
|
|
116
117
|
```
|
|
117
118
|
|
|
@@ -127,40 +128,57 @@ make build_docs
|
|
|
127
128
|
|
|
128
129
|
# Installation
|
|
129
130
|
|
|
130
|
-
|
|
131
|
+
Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
|
|
131
132
|
|
|
132
133
|
---
|
|
133
134
|
|
|
134
|
-
### Option 1:
|
|
135
|
+
### Option 1: Install from PyPI (recommended)
|
|
135
136
|
|
|
136
137
|
```bash
|
|
137
|
-
# Create and activate a
|
|
138
|
-
|
|
139
|
-
|
|
138
|
+
# Create and activate a virtual environment
|
|
139
|
+
python -m venv .venv # or `uv venv`
|
|
140
|
+
source .venv/bin/activate # on macOS/Linux
|
|
141
|
+
.venv\Scripts\activate # on Windows
|
|
140
142
|
|
|
141
|
-
#
|
|
142
|
-
pip install
|
|
143
|
+
# Upgrade pip
|
|
144
|
+
pip install --upgrade pip
|
|
145
|
+
|
|
146
|
+
# Install latest release from PyPI
|
|
147
|
+
pip install cd-dynamax
|
|
143
148
|
```
|
|
144
149
|
|
|
145
|
-
|
|
150
|
+
`cd-dynamax` is currently **not available on Conda Forge**.
|
|
146
151
|
|
|
147
152
|
---
|
|
148
153
|
|
|
149
|
-
### Option 2:
|
|
154
|
+
### Option 2: Install from source (editable)
|
|
150
155
|
|
|
151
156
|
```bash
|
|
152
157
|
# Create and activate a virtual environment
|
|
153
|
-
python -m venv .venv
|
|
158
|
+
python -m venv .venv # or `uv venv`
|
|
154
159
|
source .venv/bin/activate # on macOS/Linux
|
|
155
160
|
.venv\Scripts\activate # on Windows
|
|
156
161
|
|
|
157
162
|
# Upgrade pip
|
|
158
163
|
pip install --upgrade pip
|
|
159
164
|
|
|
160
|
-
# Install in editable mode
|
|
165
|
+
# Install in editable mode for local development
|
|
161
166
|
pip install -e .[dev]
|
|
162
167
|
```
|
|
163
168
|
|
|
169
|
+
---
|
|
170
|
+
|
|
171
|
+
### Option 3: Conda environment + pip install
|
|
172
|
+
|
|
173
|
+
```bash
|
|
174
|
+
# Create and activate a Conda environment with Python 3.11
|
|
175
|
+
conda create -n cd_dynamax python=3.11
|
|
176
|
+
conda activate cd_dynamax
|
|
177
|
+
|
|
178
|
+
# Install latest release from PyPI
|
|
179
|
+
pip install cd-dynamax
|
|
180
|
+
```
|
|
181
|
+
|
|
164
182
|
#### GPU support
|
|
165
183
|
If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
|
|
166
184
|
|
|
@@ -119,9 +119,9 @@ def _predict(
|
|
|
119
119
|
f = params.dynamics.drift.f
|
|
120
120
|
|
|
121
121
|
# Get time-varying parameters
|
|
122
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
122
|
+
Qc_t = params.dynamics.diffusion_cov.f(m, u, t)
|
|
123
123
|
L_t = (
|
|
124
|
-
params.dynamics.diffusion_coefficient.f(
|
|
124
|
+
params.dynamics.diffusion_coefficient.f(m, u, t)
|
|
125
125
|
* filter_hyperparams.cov_rescaling
|
|
126
126
|
)
|
|
127
127
|
|
|
@@ -185,9 +185,9 @@ def _predict(
|
|
|
185
185
|
dt = filter_hyperparams.dt_average
|
|
186
186
|
|
|
187
187
|
# Covariance parameters at time t0
|
|
188
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
188
|
+
Qc_t = params.dynamics.diffusion_cov.f(m, u, t0)
|
|
189
189
|
L_t = (
|
|
190
|
-
params.dynamics.diffusion_coefficient.f(
|
|
190
|
+
params.dynamics.diffusion_coefficient.f(m, u, t0)
|
|
191
191
|
* filter_hyperparams.cov_rescaling
|
|
192
192
|
)
|
|
193
193
|
# Covariance update
|
|
@@ -527,12 +527,11 @@ def _smooth(
|
|
|
527
527
|
m_smooth, P_smooth = y
|
|
528
528
|
m_filter, P_filter = args
|
|
529
529
|
|
|
530
|
-
# TODO: possibly time- and parameter-dependent functions
|
|
531
530
|
f = params.dynamics.drift.f
|
|
532
531
|
# Get time-varying parameters
|
|
533
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
532
|
+
Qc_t = params.dynamics.diffusion_cov.f(m_filter, u, t)
|
|
534
533
|
L_t = (
|
|
535
|
-
params.dynamics.diffusion_coefficient.f(
|
|
534
|
+
params.dynamics.diffusion_coefficient.f(m_filter, u, t)
|
|
536
535
|
* filter_hyperparams.cov_rescaling
|
|
537
536
|
)
|
|
538
537
|
|
|
@@ -831,7 +830,7 @@ def extended_kalman_posterior_sample(
|
|
|
831
830
|
|
|
832
831
|
# Get parameters and inputs for time t0
|
|
833
832
|
u = inputs[t0_idx]
|
|
834
|
-
Q = params.dynamics.diffusion_cov.f(
|
|
833
|
+
Q = params.dynamics.diffusion_cov.f(filtered_mean, u, t0)
|
|
835
834
|
|
|
836
835
|
# Condition on next state
|
|
837
836
|
smoothed_mean, smoothed_cov = _condition_on(
|
|
@@ -113,9 +113,9 @@ def _predict(
|
|
|
113
113
|
# First order EnKF diffusion
|
|
114
114
|
def diffusion(t, y, args):
|
|
115
115
|
# Get parameters at time t and input u
|
|
116
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
116
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, u, t)
|
|
117
117
|
L_t = (
|
|
118
|
-
params.dynamics.diffusion_coefficient.f(
|
|
118
|
+
params.dynamics.diffusion_coefficient.f(y, u, t)
|
|
119
119
|
* filter_hyperparams.cov_rescaling
|
|
120
120
|
)
|
|
121
121
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
@@ -155,26 +155,20 @@ def _predict(
|
|
|
155
155
|
# but the same amount of noise is added after each measurement.
|
|
156
156
|
dt = filter_hyperparams.dt_average
|
|
157
157
|
|
|
158
|
-
|
|
159
|
-
Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
|
|
160
|
-
L_t = (
|
|
161
|
-
params.dynamics.diffusion_coefficient.f(None, u, t0)
|
|
162
|
-
* filter_hyperparams.cov_rescaling
|
|
163
|
-
)
|
|
158
|
+
key_array = jr.split(key, x.shape[0])
|
|
164
159
|
|
|
165
|
-
|
|
166
|
-
|
|
160
|
+
def _sample_noise(x_i, key_i):
|
|
161
|
+
Qc_t = params.dynamics.diffusion_cov.f(x_i, u, t0)
|
|
162
|
+
L_t = (
|
|
163
|
+
params.dynamics.diffusion_coefficient.f(x_i, u, t0)
|
|
164
|
+
* filter_hyperparams.cov_rescaling
|
|
165
|
+
)
|
|
166
|
+
state_noise_cov = dt * L_t @ Qc_t @ L_t.T # D_hid x D_hid
|
|
167
|
+
return jr.multivariate_normal(
|
|
168
|
+
key=key_i, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
|
|
169
|
+
)
|
|
167
170
|
|
|
168
|
-
|
|
169
|
-
# Split keys for each particle
|
|
170
|
-
key_array = jr.split(key, x.shape[0])
|
|
171
|
-
# vmap over particles
|
|
172
|
-
noise = vmap(
|
|
173
|
-
lambda key: jr.multivariate_normal(
|
|
174
|
-
key=key, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
|
|
175
|
-
),
|
|
176
|
-
in_axes=0,
|
|
177
|
-
)(key_array)
|
|
171
|
+
noise = vmap(_sample_noise, in_axes=(0, 0))(x_pred, key_array)
|
|
178
172
|
|
|
179
173
|
# Add noise to predicted particles
|
|
180
174
|
x_pred += noise
|
|
@@ -214,9 +214,9 @@ def _predict(
|
|
|
214
214
|
dt = filter_hyperparams.dt_average
|
|
215
215
|
|
|
216
216
|
# Get diffusion parameters at time t0 and input u
|
|
217
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
217
|
+
Qc_t = params.dynamics.diffusion_cov.f(m, u, t0)
|
|
218
218
|
L_t = (
|
|
219
|
-
params.dynamics.diffusion_coefficient.f(
|
|
219
|
+
params.dynamics.diffusion_coefficient.f(m, u, t0)
|
|
220
220
|
* filter_hyperparams.cov_rescaling
|
|
221
221
|
)
|
|
222
222
|
# Compute state noise covariance
|
|
@@ -237,9 +237,9 @@ def _predict(
|
|
|
237
237
|
f = (
|
|
238
238
|
params.dynamics.drift.f
|
|
239
239
|
) # TODO: reconsider when we want time-varying dynamics functions
|
|
240
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
240
|
+
Qc_t = params.dynamics.diffusion_cov.f(m_t, u, t)
|
|
241
241
|
L_t = (
|
|
242
|
-
params.dynamics.diffusion_coefficient.f(
|
|
242
|
+
params.dynamics.diffusion_coefficient.f(m_t, u, t)
|
|
243
243
|
* filter_hyperparams.cov_rescaling
|
|
244
244
|
)
|
|
245
245
|
|
|
@@ -111,8 +111,8 @@ def compute_pushforward(
|
|
|
111
111
|
f = params.dynamics.drift.f
|
|
112
112
|
|
|
113
113
|
# Get time-varying parameters
|
|
114
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
115
|
-
L_t = params.dynamics.diffusion_coefficient.f(
|
|
114
|
+
Qc_t = params.dynamics.diffusion_cov.f(x, inputs, t)
|
|
115
|
+
L_t = params.dynamics.diffusion_coefficient.f(x, inputs, t)
|
|
116
116
|
|
|
117
117
|
# Different SDE approximations to the dynamics
|
|
118
118
|
# Zeroth-order (no gradient information),
|
|
@@ -964,8 +964,8 @@ def cdnlgssm_path_sample(
|
|
|
964
964
|
return params.dynamics.drift.f(y, inpt, t)
|
|
965
965
|
|
|
966
966
|
def diffusion(t, y, args):
|
|
967
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
968
|
-
L_t = params.dynamics.diffusion_coefficient.f(
|
|
967
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, inpt, t)
|
|
968
|
+
L_t = params.dynamics.diffusion_coefficient.f(y, inpt, t)
|
|
969
969
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
970
970
|
combined_diffusion = L_t @ Q_sqrt
|
|
971
971
|
return combined_diffusion
|
|
@@ -1261,9 +1261,9 @@ def cdnlgssm_forecast(
|
|
|
1261
1261
|
return params.dynamics.drift.f(y, inputs[t0_idx], t)
|
|
1262
1262
|
|
|
1263
1263
|
def diffusion(t, y, args):
|
|
1264
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
1264
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, inputs[t0_idx], t)
|
|
1265
1265
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
1266
|
-
L_t = params.dynamics.diffusion_coefficient.f(
|
|
1266
|
+
L_t = params.dynamics.diffusion_coefficient.f(y, inputs[t0_idx], t)
|
|
1267
1267
|
combined_diffusion = L_t @ Q_sqrt
|
|
1268
1268
|
return combined_diffusion
|
|
1269
1269
|
|
|
@@ -112,9 +112,9 @@ def _predict(
|
|
|
112
112
|
else:
|
|
113
113
|
|
|
114
114
|
def diffusion(t, y, args):
|
|
115
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
115
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, u, t)
|
|
116
116
|
L_t = (
|
|
117
|
-
params.dynamics.diffusion_coefficient.f(
|
|
117
|
+
params.dynamics.diffusion_coefficient.f(y, u, t)
|
|
118
118
|
* filter_hyperparams.cov_rescaling
|
|
119
119
|
)
|
|
120
120
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
@@ -143,20 +143,20 @@ def _predict(
|
|
|
143
143
|
else:
|
|
144
144
|
dt = filter_hyperparams.dt_average
|
|
145
145
|
|
|
146
|
-
Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
|
|
147
|
-
L_t = (
|
|
148
|
-
params.dynamics.diffusion_coefficient.f(None, u, t0)
|
|
149
|
-
* filter_hyperparams.cov_rescaling
|
|
150
|
-
)
|
|
151
|
-
state_noise_cov = dt * L_t @ Qc_t @ L_t.T # D_hid x D_hid
|
|
152
|
-
|
|
153
146
|
key_array = jr.split(key, x.shape[0])
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
147
|
+
|
|
148
|
+
def _sample_noise(x_i, key_i):
|
|
149
|
+
Qc_t = params.dynamics.diffusion_cov.f(x_i, u, t0)
|
|
150
|
+
L_t = (
|
|
151
|
+
params.dynamics.diffusion_coefficient.f(x_i, u, t0)
|
|
152
|
+
* filter_hyperparams.cov_rescaling
|
|
153
|
+
)
|
|
154
|
+
state_noise_cov = dt * L_t @ Qc_t @ L_t.T
|
|
155
|
+
return jr.multivariate_normal(
|
|
156
|
+
key=key_i, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
noise = vmap(_sample_noise, in_axes=(0, 0))(x_pred, key_array)
|
|
160
160
|
x_pred += noise
|
|
161
161
|
return x_pred
|
|
162
162
|
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py
RENAMED
|
@@ -387,8 +387,8 @@ class ContDiscreteNonlinearSSM(SSM):
|
|
|
387
387
|
return params.dynamics.drift.f(y, u_prev_t, t)
|
|
388
388
|
|
|
389
389
|
def diffusion(t, y, _):
|
|
390
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
391
|
-
L_t = params.dynamics.diffusion_coefficient.f(
|
|
390
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, u_prev_t, t)
|
|
391
|
+
L_t = params.dynamics.diffusion_coefficient.f(y, u_prev_t, t)
|
|
392
392
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
393
393
|
return L_t @ Q_sqrt
|
|
394
394
|
|
|
@@ -818,10 +818,10 @@ def cdnlssm_forecast(
|
|
|
818
818
|
return params.dynamics.drift.f(y, forecast_inputs[t0_idx], t)
|
|
819
819
|
|
|
820
820
|
def diffusion(t, y, args):
|
|
821
|
-
Qc_t = params.dynamics.diffusion_cov.f(
|
|
821
|
+
Qc_t = params.dynamics.diffusion_cov.f(y, forecast_inputs[t0_idx], t)
|
|
822
822
|
Q_sqrt = jnp.linalg.cholesky(Qc_t)
|
|
823
823
|
L_t = params.dynamics.diffusion_coefficient.f(
|
|
824
|
-
|
|
824
|
+
y, forecast_inputs[t0_idx], t
|
|
825
825
|
)
|
|
826
826
|
combined_diffusion = L_t @ Q_sqrt
|
|
827
827
|
return combined_diffusion
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cd_dynamax
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.3
|
|
4
4
|
Summary: Continuous-discrete dynamical systems with JAX and related libraries.
|
|
5
5
|
Author: Matthew Levine, Iñigo Urteaga
|
|
6
6
|
Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
|
|
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
|
|
|
32
32
|
Requires-Dist: fastprogress>=1.0.0
|
|
33
33
|
Requires-Dist: graphviz
|
|
34
34
|
Requires-Dist: ipykernel
|
|
35
|
+
Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
|
|
35
36
|
Provides-Extra: dev
|
|
36
37
|
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
37
38
|
Requires-Dist: ruff; extra == "dev"
|
|
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
|
|
|
69
70
|
|
|
70
71
|
where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
|
|
71
72
|
|
|
72
|
-
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\
|
|
73
|
+
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
|
|
73
74
|
|
|
74
75
|
Note:
|
|
75
76
|
|
|
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
|
|
|
90
91
|
|
|
91
92
|
All of these problems are deeply interconnected.
|
|
92
93
|
|
|
93
|
-
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]
|
|
94
|
+
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$).
|
|
94
95
|
|
|
95
96
|
- In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
|
|
96
97
|
- In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
|
|
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
|
|
|
113
114
|
|
|
114
115
|
- The codebase is organized into several key directories:
|
|
115
116
|
```
|
|
116
|
-
|
|
117
|
-
├──
|
|
118
|
-
│ ├──
|
|
119
|
-
│ ├──
|
|
120
|
-
│ ├──
|
|
121
|
-
│
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
├──
|
|
125
|
-
├── python/
|
|
126
|
-
├── python/
|
|
127
|
-
|
|
117
|
+
.
|
|
118
|
+
├── cd_dynamax/ # Source code for cd-dynamax library
|
|
119
|
+
│ ├── src/ # Core source code
|
|
120
|
+
│ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
|
|
121
|
+
│ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
|
|
122
|
+
│ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
|
|
123
|
+
│ │ └── utils/ # Utility functions and example models
|
|
124
|
+
│ └── dynamax/ # Original dynamax library (as a submodule)
|
|
125
|
+
├── demos/ # Python demos showcasing cd-dynamax functionality
|
|
126
|
+
│ ├── python/scripts/ # Python scripts for running demos
|
|
127
|
+
│ ├── python/notebooks/ # Jupyter notebooks for interactive demos
|
|
128
|
+
│ └── python/configs/ # Configuration files for demos
|
|
129
|
+
└── tests/ # Tests for cd-dynamax functionality
|
|
128
130
|
```
|
|
129
131
|
|
|
130
132
|
## [Demos](./demos/python)
|
|
@@ -155,7 +157,7 @@ make test
|
|
|
155
157
|
```
|
|
156
158
|
|
|
157
159
|
- For linting, we use `ruff`:
|
|
158
|
-
```
|
|
160
|
+
```bash
|
|
159
161
|
make lint
|
|
160
162
|
```
|
|
161
163
|
|
|
@@ -171,40 +173,57 @@ make build_docs
|
|
|
171
173
|
|
|
172
174
|
# Installation
|
|
173
175
|
|
|
174
|
-
|
|
176
|
+
Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
|
|
175
177
|
|
|
176
178
|
---
|
|
177
179
|
|
|
178
|
-
### Option 1:
|
|
180
|
+
### Option 1: Install from PyPI (recommended)
|
|
179
181
|
|
|
180
182
|
```bash
|
|
181
|
-
# Create and activate a
|
|
182
|
-
|
|
183
|
-
|
|
183
|
+
# Create and activate a virtual environment
|
|
184
|
+
python -m venv .venv # or `uv venv`
|
|
185
|
+
source .venv/bin/activate # on macOS/Linux
|
|
186
|
+
.venv\Scripts\activate # on Windows
|
|
184
187
|
|
|
185
|
-
#
|
|
186
|
-
pip install
|
|
188
|
+
# Upgrade pip
|
|
189
|
+
pip install --upgrade pip
|
|
190
|
+
|
|
191
|
+
# Install latest release from PyPI
|
|
192
|
+
pip install cd-dynamax
|
|
187
193
|
```
|
|
188
194
|
|
|
189
|
-
|
|
195
|
+
`cd-dynamax` is currently **not available on Conda Forge**.
|
|
190
196
|
|
|
191
197
|
---
|
|
192
198
|
|
|
193
|
-
### Option 2:
|
|
199
|
+
### Option 2: Install from source (editable)
|
|
194
200
|
|
|
195
201
|
```bash
|
|
196
202
|
# Create and activate a virtual environment
|
|
197
|
-
python -m venv .venv
|
|
203
|
+
python -m venv .venv # or `uv venv`
|
|
198
204
|
source .venv/bin/activate # on macOS/Linux
|
|
199
205
|
.venv\Scripts\activate # on Windows
|
|
200
206
|
|
|
201
207
|
# Upgrade pip
|
|
202
208
|
pip install --upgrade pip
|
|
203
209
|
|
|
204
|
-
# Install in editable mode
|
|
210
|
+
# Install in editable mode for local development
|
|
205
211
|
pip install -e .[dev]
|
|
206
212
|
```
|
|
207
213
|
|
|
214
|
+
---
|
|
215
|
+
|
|
216
|
+
### Option 3: Conda environment + pip install
|
|
217
|
+
|
|
218
|
+
```bash
|
|
219
|
+
# Create and activate a Conda environment with Python 3.11
|
|
220
|
+
conda create -n cd_dynamax python=3.11
|
|
221
|
+
conda activate cd_dynamax
|
|
222
|
+
|
|
223
|
+
# Install latest release from PyPI
|
|
224
|
+
pip install cd-dynamax
|
|
225
|
+
```
|
|
226
|
+
|
|
208
227
|
#### GPU support
|
|
209
228
|
If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
|
|
210
229
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "cd_dynamax"
|
|
7
|
-
version = "0.3.
|
|
7
|
+
version = "0.3.3"
|
|
8
8
|
requires-python = ">=3.11"
|
|
9
9
|
description = "Continuous-discrete dynamical systems with JAX and related libraries."
|
|
10
10
|
readme = "README.md"
|
|
@@ -42,7 +42,8 @@ dependencies = [
|
|
|
42
42
|
"dm-tree>=0.1.8",
|
|
43
43
|
"fastprogress>=1.0.0",
|
|
44
44
|
"graphviz",
|
|
45
|
-
"ipykernel"
|
|
45
|
+
"ipykernel",
|
|
46
|
+
"orbax-checkpoint<0.11.3; sys_platform == 'win32'"
|
|
46
47
|
]
|
|
47
48
|
|
|
48
49
|
[tool.setuptools.packages.find]
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jax.random as jr
|
|
3
|
+
from typing import NamedTuple
|
|
4
|
+
|
|
5
|
+
from cd_dynamax.src.continuous_discrete_nonlinear_gaussian_ssm import (
|
|
6
|
+
ContDiscreteNonlinearGaussianSSM,
|
|
7
|
+
)
|
|
8
|
+
from cd_dynamax.src.continuous_discrete_nonlinear_gaussian_ssm.inference_ekf import (
|
|
9
|
+
EKFHyperParams,
|
|
10
|
+
extended_kalman_filter,
|
|
11
|
+
)
|
|
12
|
+
from cd_dynamax.src.continuous_discrete_nonlinear_ssm import ContDiscreteNonlinearSSM
|
|
13
|
+
from cd_dynamax.src.continuous_discrete_nonlinear_ssm.inference_dpf import (
|
|
14
|
+
DPFHyperParams,
|
|
15
|
+
diff_particle_filter,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class StateDependentDiagonalDiffusion(NamedTuple):
|
|
20
|
+
scale: float = 0.1
|
|
21
|
+
|
|
22
|
+
def f(self, x, u=None, t=None):
|
|
23
|
+
if x is None:
|
|
24
|
+
raise ValueError("state-dependent diffusion requires a state argument")
|
|
25
|
+
x = jnp.atleast_1d(x)
|
|
26
|
+
return jnp.diag(1.0 + self.scale * jnp.square(x))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_cdnlgssm_sample_path_supports_state_dependent_diffusion():
|
|
30
|
+
model = ContDiscreteNonlinearGaussianSSM(state_dim=1, emission_dim=1)
|
|
31
|
+
params, _ = model.initialize()
|
|
32
|
+
params = params._replace(
|
|
33
|
+
dynamics=params.dynamics._replace(
|
|
34
|
+
diffusion_coefficient=StateDependentDiagonalDiffusion()
|
|
35
|
+
)
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
states, emissions = model.sample_path(params, key=jr.PRNGKey(0), num_timesteps=3)
|
|
39
|
+
|
|
40
|
+
assert states.shape == (3, 1)
|
|
41
|
+
assert emissions.shape == (3, 1)
|
|
42
|
+
assert jnp.all(jnp.isfinite(states))
|
|
43
|
+
assert jnp.all(jnp.isfinite(emissions))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_extended_kalman_filter_supports_state_dependent_diffusion():
|
|
47
|
+
model = ContDiscreteNonlinearGaussianSSM(state_dim=1, emission_dim=1)
|
|
48
|
+
params, _ = model.initialize()
|
|
49
|
+
params = params._replace(
|
|
50
|
+
dynamics=params.dynamics._replace(
|
|
51
|
+
diffusion_coefficient=StateDependentDiagonalDiffusion()
|
|
52
|
+
)
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
emissions = jnp.zeros((2, 1))
|
|
56
|
+
t_emissions = jnp.array([[0.0], [0.1]])
|
|
57
|
+
|
|
58
|
+
posterior = extended_kalman_filter(
|
|
59
|
+
params,
|
|
60
|
+
emissions,
|
|
61
|
+
t_emissions=t_emissions,
|
|
62
|
+
filter_hyperparams=EKFHyperParams(state_order="first"),
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
assert posterior.filtered_means.shape == (2, 1)
|
|
66
|
+
assert posterior.predicted_covariances.shape == (2, 1, 1)
|
|
67
|
+
assert jnp.all(jnp.isfinite(posterior.filtered_means))
|
|
68
|
+
assert jnp.all(jnp.isfinite(posterior.predicted_covariances))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_diff_particle_filter_supports_state_dependent_diffusion():
|
|
72
|
+
model = ContDiscreteNonlinearSSM(state_dim=1, emission_dim=1)
|
|
73
|
+
params, _ = model.initialize()
|
|
74
|
+
params = params._replace(
|
|
75
|
+
dynamics=params.dynamics._replace(
|
|
76
|
+
diffusion_coefficient=StateDependentDiagonalDiffusion()
|
|
77
|
+
)
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
emissions = jnp.zeros((2, 1))
|
|
81
|
+
ts = jnp.array([0.0, 0.1])
|
|
82
|
+
|
|
83
|
+
particles, log_weights, log_evidence = diff_particle_filter(
|
|
84
|
+
jr.PRNGKey(0),
|
|
85
|
+
params,
|
|
86
|
+
emissions,
|
|
87
|
+
ts=ts,
|
|
88
|
+
hyperparams=DPFHyperParams(N_particles=8),
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
assert particles.shape == (2, 8, 1)
|
|
92
|
+
assert log_weights.shape == (2, 8)
|
|
93
|
+
assert jnp.isfinite(log_evidence)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py
RENAMED
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/inference_test.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/initial.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py
RENAMED
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py
RENAMED
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|