cd-dynamax 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/PKG-INFO +46 -27
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/README.md +44 -26
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +76 -3
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +6 -2
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +7 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +81 -27
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +28 -2
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +72 -20
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +102 -29
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +87 -36
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +60 -19
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +15 -15
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +22 -4
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/demo_utils.py +7 -7
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/PKG-INFO +46 -27
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/SOURCES.txt +3 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/requires.txt +3 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/pyproject.toml +3 -2
- cd_dynamax-0.3.4/tests/test_linear_gaussian_filter_posterior_extras.py +107 -0
- cd_dynamax-0.3.4/tests/test_nonlinear_gaussian_filter_posterior_extras.py +237 -0
- cd_dynamax-0.3.4/tests/test_state_dependent_diffusion.py +93 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/LICENSE +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/_version.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/initial.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/parameters.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/parameters_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/inference.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/inference_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/ssm.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/types.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/bijectors.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/distributions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/distributions_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/optimize.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/plotting.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/utils_test.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/warnings.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/ssm_temissions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/__init__.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/data_driven_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/data_generator.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/debug_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/diffrax_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/evaluation_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/experiment_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/likelihood_eval_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/optimize_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/physics_based_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/plotting_chaos_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/plotting_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/prior_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/simulation_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/test_utils.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/dependency_links.txt +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/top_level.txt +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/setup.cfg +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_cdlgssm_dlgssm_match.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_cdnonlinear_cdlinear_match.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_filter_forecast_emissions.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_imports.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_models.py +0 -0
- {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_utils_imports.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cd_dynamax
|
|
3
|
-
Version: 0.3.2
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: Continuous-discrete dynamical systems with JAX and related libraries.
|
|
5
5
|
Author: Matthew Levine, Iñigo Urteaga
|
|
6
6
|
Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
|
|
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
|
|
|
32
32
|
Requires-Dist: fastprogress>=1.0.0
|
|
33
33
|
Requires-Dist: graphviz
|
|
34
34
|
Requires-Dist: ipykernel
|
|
35
|
+
Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
|
|
35
36
|
Provides-Extra: dev
|
|
36
37
|
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
37
38
|
Requires-Dist: ruff; extra == "dev"
|
|
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
|
|
|
69
70
|
|
|
70
71
|
where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
|
|
71
72
|
|
|
72
|
-
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\
|
|
73
|
+
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
|
|
73
74
|
|
|
74
75
|
Note:
|
|
75
76
|
|
|
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
|
|
|
90
91
|
|
|
91
92
|
All of these problems are deeply interconnected.
|
|
92
93
|
|
|
93
|
-
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]
|
|
94
|
+
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$).
|
|
94
95
|
|
|
95
96
|
- In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
|
|
96
97
|
- In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
|
|
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
|
|
|
113
114
|
|
|
114
115
|
- The codebase is organized into several key directories:
|
|
115
116
|
```
|
|
116
|
-
|
|
117
|
-
├──
|
|
118
|
-
│ ├──
|
|
119
|
-
│ ├──
|
|
120
|
-
│ ├──
|
|
121
|
-
│
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
├──
|
|
125
|
-
├── python/
|
|
126
|
-
├── python/
|
|
127
|
-
|
|
117
|
+
.
|
|
118
|
+
├── cd_dynamax/ # Source code for cd-dynamax library
|
|
119
|
+
│ ├── src/ # Core source code
|
|
120
|
+
│ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
|
|
121
|
+
│ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
|
|
122
|
+
│ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
|
|
123
|
+
│ │ └── utils/ # Utility functions and example models
|
|
124
|
+
│ └── dynamax/ # Original dynamax library (as a submodule)
|
|
125
|
+
├── demos/ # Python demos showcasing cd-dynamax functionality
|
|
126
|
+
│ ├── python/scripts/ # Python scripts for running demos
|
|
127
|
+
│ ├── python/notebooks/ # Jupyter notebooks for interactive demos
|
|
128
|
+
│ └── python/configs/ # Configuration files for demos
|
|
129
|
+
└── tests/ # Tests for cd-dynamax functionality
|
|
128
130
|
```
|
|
129
131
|
|
|
130
132
|
## [Demos](./demos/python)
|
|
@@ -155,7 +157,7 @@ make test
|
|
|
155
157
|
```
|
|
156
158
|
|
|
157
159
|
- For linting, we use `ruff`:
|
|
158
|
-
```
|
|
160
|
+
```bash
|
|
159
161
|
make lint
|
|
160
162
|
```
|
|
161
163
|
|
|
@@ -171,40 +173,57 @@ make build_docs
|
|
|
171
173
|
|
|
172
174
|
# Installation
|
|
173
175
|
|
|
174
|
-
|
|
176
|
+
Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
|
|
175
177
|
|
|
176
178
|
---
|
|
177
179
|
|
|
178
|
-
### Option 1:
|
|
180
|
+
### Option 1: Install from PyPI (recommended)
|
|
179
181
|
|
|
180
182
|
```bash
|
|
181
|
-
# Create and activate a
|
|
182
|
-
|
|
183
|
-
|
|
183
|
+
# Create and activate a virtual environment
|
|
184
|
+
python -m venv .venv # or `uv venv`
|
|
185
|
+
source .venv/bin/activate # on macOS/Linux
|
|
186
|
+
.venv\Scripts\activate # on Windows
|
|
184
187
|
|
|
185
|
-
#
|
|
186
|
-
pip install
|
|
188
|
+
# Upgrade pip
|
|
189
|
+
pip install --upgrade pip
|
|
190
|
+
|
|
191
|
+
# Install latest release from PyPI
|
|
192
|
+
pip install cd-dynamax
|
|
187
193
|
```
|
|
188
194
|
|
|
189
|
-
|
|
195
|
+
`cd-dynamax` is currently **not available on Conda Forge**.
|
|
190
196
|
|
|
191
197
|
---
|
|
192
198
|
|
|
193
|
-
### Option 2:
|
|
199
|
+
### Option 2: Install from source (editable)
|
|
194
200
|
|
|
195
201
|
```bash
|
|
196
202
|
# Create and activate a virtual environment
|
|
197
|
-
python -m venv .venv
|
|
203
|
+
python -m venv .venv # or `uv venv`
|
|
198
204
|
source .venv/bin/activate # on macOS/Linux
|
|
199
205
|
.venv\Scripts\activate # on Windows
|
|
200
206
|
|
|
201
207
|
# Upgrade pip
|
|
202
208
|
pip install --upgrade pip
|
|
203
209
|
|
|
204
|
-
# Install in editable mode
|
|
210
|
+
# Install in editable mode for local development
|
|
205
211
|
pip install -e .[dev]
|
|
206
212
|
```
|
|
207
213
|
|
|
214
|
+
---
|
|
215
|
+
|
|
216
|
+
### Option 3: Conda environment + pip install
|
|
217
|
+
|
|
218
|
+
```bash
|
|
219
|
+
# Create and activate a Conda environment with Python 3.11
|
|
220
|
+
conda create -n cd_dynamax python=3.11
|
|
221
|
+
conda activate cd_dynamax
|
|
222
|
+
|
|
223
|
+
# Install latest release from PyPI
|
|
224
|
+
pip install cd-dynamax
|
|
225
|
+
```
|
|
226
|
+
|
|
208
227
|
#### GPU support
|
|
209
228
|
If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
|
|
210
229
|
|
|
@@ -25,7 +25,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
|
|
|
25
25
|
|
|
26
26
|
where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
|
|
27
27
|
|
|
28
|
-
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\
|
|
28
|
+
We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
|
|
29
29
|
|
|
30
30
|
Note:
|
|
31
31
|
|
|
@@ -46,7 +46,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
|
|
|
46
46
|
|
|
47
47
|
All of these problems are deeply interconnected.
|
|
48
48
|
|
|
49
|
-
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]
|
|
49
|
+
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$).
|
|
50
50
|
|
|
51
51
|
- In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
|
|
52
52
|
- In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
|
|
@@ -69,18 +69,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
|
|
|
69
69
|
|
|
70
70
|
- The codebase is organized into several key directories:
|
|
71
71
|
```
|
|
72
|
-
|
|
73
|
-
├──
|
|
74
|
-
│ ├──
|
|
75
|
-
│ ├──
|
|
76
|
-
│ ├──
|
|
77
|
-
│
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
├──
|
|
81
|
-
├── python/
|
|
82
|
-
├── python/
|
|
83
|
-
|
|
72
|
+
.
|
|
73
|
+
├── cd_dynamax/ # Source code for cd-dynamax library
|
|
74
|
+
│ ├── src/ # Core source code
|
|
75
|
+
│ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
|
|
76
|
+
│ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
|
|
77
|
+
│ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
|
|
78
|
+
│ │ └── utils/ # Utility functions and example models
|
|
79
|
+
│ └── dynamax/ # Original dynamax library (as a submodule)
|
|
80
|
+
├── demos/ # Python demos showcasing cd-dynamax functionality
|
|
81
|
+
│ ├── python/scripts/ # Python scripts for running demos
|
|
82
|
+
│ ├── python/notebooks/ # Jupyter notebooks for interactive demos
|
|
83
|
+
│ └── python/configs/ # Configuration files for demos
|
|
84
|
+
└── tests/ # Tests for cd-dynamax functionality
|
|
84
85
|
```
|
|
85
86
|
|
|
86
87
|
## [Demos](./demos/python)
|
|
@@ -111,7 +112,7 @@ make test
|
|
|
111
112
|
```
|
|
112
113
|
|
|
113
114
|
- For linting, we use `ruff`:
|
|
114
|
-
```
|
|
115
|
+
```bash
|
|
115
116
|
make lint
|
|
116
117
|
```
|
|
117
118
|
|
|
@@ -127,40 +128,57 @@ make build_docs
|
|
|
127
128
|
|
|
128
129
|
# Installation
|
|
129
130
|
|
|
130
|
-
|
|
131
|
+
Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
|
|
131
132
|
|
|
132
133
|
---
|
|
133
134
|
|
|
134
|
-
### Option 1:
|
|
135
|
+
### Option 1: Install from PyPI (recommended)
|
|
135
136
|
|
|
136
137
|
```bash
|
|
137
|
-
# Create and activate a
|
|
138
|
-
|
|
139
|
-
|
|
138
|
+
# Create and activate a virtual environment
|
|
139
|
+
python -m venv .venv # or `uv venv`
|
|
140
|
+
source .venv/bin/activate # on macOS/Linux
|
|
141
|
+
.venv\Scripts\activate # on Windows
|
|
140
142
|
|
|
141
|
-
#
|
|
142
|
-
pip install
|
|
143
|
+
# Upgrade pip
|
|
144
|
+
pip install --upgrade pip
|
|
145
|
+
|
|
146
|
+
# Install latest release from PyPI
|
|
147
|
+
pip install cd-dynamax
|
|
143
148
|
```
|
|
144
149
|
|
|
145
|
-
|
|
150
|
+
`cd-dynamax` is currently **not available on Conda Forge**.
|
|
146
151
|
|
|
147
152
|
---
|
|
148
153
|
|
|
149
|
-
### Option 2:
|
|
154
|
+
### Option 2: Install from source (editable)
|
|
150
155
|
|
|
151
156
|
```bash
|
|
152
157
|
# Create and activate a virtual environment
|
|
153
|
-
python -m venv .venv
|
|
158
|
+
python -m venv .venv # or `uv venv`
|
|
154
159
|
source .venv/bin/activate # on macOS/Linux
|
|
155
160
|
.venv\Scripts\activate # on Windows
|
|
156
161
|
|
|
157
162
|
# Upgrade pip
|
|
158
163
|
pip install --upgrade pip
|
|
159
164
|
|
|
160
|
-
# Install in editable mode
|
|
165
|
+
# Install in editable mode for local development
|
|
161
166
|
pip install -e .[dev]
|
|
162
167
|
```
|
|
163
168
|
|
|
169
|
+
---
|
|
170
|
+
|
|
171
|
+
### Option 3: Conda environment + pip install
|
|
172
|
+
|
|
173
|
+
```bash
|
|
174
|
+
# Create and activate a Conda environment with Python 3.11
|
|
175
|
+
conda create -n cd_dynamax python=3.11
|
|
176
|
+
conda activate cd_dynamax
|
|
177
|
+
|
|
178
|
+
# Install latest release from PyPI
|
|
179
|
+
pip install cd-dynamax
|
|
180
|
+
```
|
|
181
|
+
|
|
164
182
|
#### GPU support
|
|
165
183
|
If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
|
|
166
184
|
|
|
@@ -11,7 +11,7 @@ from tensorflow_probability.substrates.jax.distributions import (
|
|
|
11
11
|
|
|
12
12
|
from jax.tree_util import tree_map
|
|
13
13
|
from jaxtyping import Array, Float
|
|
14
|
-
from typing import NamedTuple, Optional, Union, Tuple
|
|
14
|
+
from typing import Collection, NamedTuple, Optional, Union, Tuple
|
|
15
15
|
from ..utils.utils import psd_solve, symmetrize
|
|
16
16
|
from ..parameters import ParameterProperties
|
|
17
17
|
from ..types import PRNGKey, Scalar
|
|
@@ -112,9 +112,27 @@ class ParamsLGSSM(NamedTuple):
|
|
|
112
112
|
class PosteriorGSSMFiltered(NamedTuple):
|
|
113
113
|
r"""Marginals of the Gaussian filtering posterior.
|
|
114
114
|
|
|
115
|
+
Time indexing convention:
|
|
116
|
+
`filtered_*[t]` corresponds to $x_{t \mid t}$,
|
|
117
|
+
`predicted_*[t]` corresponds to $x_{t+1 \mid t}$,
|
|
118
|
+
and predictive observation quantities such as `y_pred_*[t]` and
|
|
119
|
+
`y_obs_pred_*[t]` correspond to the observation-time prior
|
|
120
|
+
$p(y_t \mid y_{1:t-1})$, i.e. they are computed from $x_{t \mid t-1}$.
|
|
121
|
+
|
|
115
122
|
:param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
|
|
116
|
-
:param filtered_means: array of filtered means $\mathbb{E}[
|
|
117
|
-
:param filtered_covariances: array of filtered covariances $\mathrm{Cov}[
|
|
123
|
+
:param filtered_means: array of filtered means $\mathbb{E}[x_t \mid y_{1:t}, u_{1:t}]$
|
|
124
|
+
:param filtered_covariances: array of filtered covariances $\mathrm{Cov}[x_t \mid y_{1:t}, u_{1:t}]$
|
|
125
|
+
:param predicted_means: optional array of one-step-ahead state means $\mathbb{E}[x_{t+1} \mid y_{1:t}, u_{1:t}]$
|
|
126
|
+
:param predicted_covariances: optional array of one-step-ahead state covariances $\mathrm{Cov}[x_{t+1} \mid y_{1:t}, u_{1:t}]$
|
|
127
|
+
:param filtered_ensembles: When filtering with EnKF, optional array of filtered state ensembles approximating $p(x_t \mid y_{1:t}, u_{1:t})$
|
|
128
|
+
:param predicted_ensembles: When filtering with EnKF, optional array of one-step-ahead predicted state ensembles approximating $p(x_{t+1} \mid y_{1:t}, u_{1:t})$
|
|
129
|
+
:param y_pred_mean: optional array of predictive emission means for $p(h(x_t) \mid y_{1:t-1}, u_{1:t})$, i.e., before adding observation noise to h(x_t)
|
|
130
|
+
:param y_pred_cov: optional array of predictive emission covariances for $p(h(x_t) \mid y_{1:t-1}, u_{1:t})$, i.e., before adding observation noise to h(x_t)
|
|
131
|
+
:param y_obs_pred_mean: optional array of predictive observation means for $p(y_t \mid y_{1:t-1}, u_{1:t})$
|
|
132
|
+
:param y_obs_pred_cov: optional array of predictive observation covariances for $p(y_t \mid y_{1:t-1}, u_{1:t})$, including observation noise
|
|
133
|
+
:param y_ens_pred: When filtering with EnKF, optional array of predictive emission ensembles for observation time $t$, obtained from the predicted state ensemble for $x_{t \mid t-1}$
|
|
134
|
+
:param y_obs_ens_pred: When filtering with EnKF, optional array of predictive observation ensembles for observation time $t$, obtained by adding sampled observation noise to `y_ens_pred`
|
|
135
|
+
:param posterior_extras: When filtering with EnKF, optional dictionary of filter-specific per-step diagnostics that do not fit the standard filtered posterior schema
|
|
118
136
|
|
|
119
137
|
"""
|
|
120
138
|
# Default attributes
|
|
@@ -123,9 +141,64 @@ class PosteriorGSSMFiltered(NamedTuple):
|
|
|
123
141
|
filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
|
|
124
142
|
predicted_means: Optional[Float[Array, "ntime state_dim"]] = None
|
|
125
143
|
predicted_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
|
|
144
|
+
filtered_ensembles: Optional[Float[Array, "ntime ensemble_size state_dim"]] = None
|
|
145
|
+
predicted_ensembles: Optional[Float[Array, "ntime ensemble_size state_dim"]] = None
|
|
146
|
+
y_pred_mean: Optional[Float[Array, "ntime emission_dim"]] = None
|
|
147
|
+
y_pred_cov: Optional[Float[Array, "ntime emission_dim emission_dim"]] = None
|
|
148
|
+
y_obs_pred_mean: Optional[Float[Array, "ntime emission_dim"]] = None
|
|
149
|
+
y_obs_pred_cov: Optional[Float[Array, "ntime emission_dim emission_dim"]] = None
|
|
150
|
+
y_ens_pred: Optional[Float[Array, "ntime ensemble_size emission_dim"]] = None
|
|
151
|
+
y_obs_ens_pred: Optional[Float[Array, "ntime ensemble_size emission_dim"]] = None
|
|
126
152
|
# Additional extras
|
|
127
153
|
posterior_extras: Optional[dict] = None
|
|
128
154
|
|
|
155
|
+
|
|
156
|
+
FILTERED_POSTERIOR_FIELD_NAMES = (
|
|
157
|
+
"marginal_loglik",
|
|
158
|
+
"filtered_means",
|
|
159
|
+
"filtered_covariances",
|
|
160
|
+
"predicted_means",
|
|
161
|
+
"predicted_covariances",
|
|
162
|
+
"filtered_ensembles",
|
|
163
|
+
"predicted_ensembles",
|
|
164
|
+
"y_pred_mean",
|
|
165
|
+
"y_pred_cov",
|
|
166
|
+
"y_obs_pred_mean",
|
|
167
|
+
"y_obs_pred_cov",
|
|
168
|
+
"y_ens_pred",
|
|
169
|
+
"y_obs_ens_pred",
|
|
170
|
+
"posterior_extras",
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def validate_filtered_posterior_output_fields(
|
|
175
|
+
filter_name: str,
|
|
176
|
+
output_fields: Optional[Collection[str]],
|
|
177
|
+
supported_fields: Collection[str],
|
|
178
|
+
):
|
|
179
|
+
"""Validate requested top-level filtered posterior fields."""
|
|
180
|
+
|
|
181
|
+
if output_fields is None:
|
|
182
|
+
return
|
|
183
|
+
|
|
184
|
+
valid_fields = set(FILTERED_POSTERIOR_FIELD_NAMES)
|
|
185
|
+
supported_fields = set(supported_fields)
|
|
186
|
+
requested_fields = list(output_fields)
|
|
187
|
+
|
|
188
|
+
unknown_fields = sorted(set(requested_fields) - valid_fields)
|
|
189
|
+
if unknown_fields:
|
|
190
|
+
raise ValueError(
|
|
191
|
+
f"Unknown output_fields for {filter_name}: {unknown_fields}. "
|
|
192
|
+
f"Valid top-level filtered posterior fields are: {sorted(valid_fields)}."
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
unsupported_fields = sorted(set(requested_fields) - supported_fields)
|
|
196
|
+
if unsupported_fields:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"output_fields {unsupported_fields} are not available for {filter_name}. "
|
|
199
|
+
f"Available fields are: {sorted(supported_fields)}."
|
|
200
|
+
)
|
|
201
|
+
|
|
129
202
|
class PosteriorGSSMSmoothed(NamedTuple):
|
|
130
203
|
r"""Marginals of the Gaussian filtering and smoothing posterior.
|
|
131
204
|
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
RENAMED
|
@@ -101,8 +101,12 @@ def extended_kalman_filter(
|
|
|
101
101
|
num_iter: number of linearizations around posterior for update step (default 1).
|
|
102
102
|
inputs: optional array of inputs.
|
|
103
103
|
output_fields: list of fields to return in posterior object.
|
|
104
|
-
|
|
105
|
-
"
|
|
104
|
+
Options:
|
|
105
|
+
`"filtered_means"` (default)
|
|
106
|
+
`"filtered_covariances"` (default)
|
|
107
|
+
`"predicted_means"` (default)
|
|
108
|
+
`"predicted_covariances"` (default)
|
|
109
|
+
`"marginal_loglik"`
|
|
106
110
|
|
|
107
111
|
Returns:
|
|
108
112
|
post: posterior object.
|
{cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py
RENAMED
|
@@ -151,6 +151,13 @@ def unscented_kalman_filter(
|
|
|
151
151
|
emissions: array of observations.
|
|
152
152
|
hyperparams: hyper-parameters.
|
|
153
153
|
inputs: optional array of inputs.
|
|
154
|
+
output_fields: list of fields to return in posterior object.
|
|
155
|
+
Options:
|
|
156
|
+
`"filtered_means"` (default)
|
|
157
|
+
`"filtered_covariances"` (default)
|
|
158
|
+
`"predicted_means"` (default)
|
|
159
|
+
`"predicted_covariances"` (default)
|
|
160
|
+
`"marginal_loglik"`
|
|
154
161
|
|
|
155
162
|
Returns:
|
|
156
163
|
filtered_posterior: posterior object.
|
|
@@ -27,6 +27,7 @@ from cd_dynamax.dynamax.utils.utils import psd_solve
|
|
|
27
27
|
from cd_dynamax.dynamax.linear_gaussian_ssm.inference import (
|
|
28
28
|
PosteriorGSSMFiltered,
|
|
29
29
|
PosteriorGSSMSmoothed,
|
|
30
|
+
validate_filtered_posterior_output_fields,
|
|
30
31
|
)
|
|
31
32
|
|
|
32
33
|
# Initial and emission parameter classes are equivalent
|
|
@@ -54,6 +55,19 @@ tfb = tfp.bijectors
|
|
|
54
55
|
DEBUG = False
|
|
55
56
|
|
|
56
57
|
|
|
58
|
+
CDLGSSM_FILTER_OUTPUT_FIELDS = (
|
|
59
|
+
"marginal_loglik",
|
|
60
|
+
"filtered_means",
|
|
61
|
+
"filtered_covariances",
|
|
62
|
+
"predicted_means",
|
|
63
|
+
"predicted_covariances",
|
|
64
|
+
"y_pred_mean",
|
|
65
|
+
"y_pred_cov",
|
|
66
|
+
"y_obs_pred_mean",
|
|
67
|
+
"y_obs_pred_cov",
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
|
|
57
71
|
#### Helper functions
|
|
58
72
|
# Helper functions --- modified from dynamax
|
|
59
73
|
def _get_params(x, dim, t):
|
|
@@ -137,20 +151,27 @@ def preprocess_args(f):
|
|
|
137
151
|
filter_hyperparams = bound_args.arguments["filter_hyperparams"]
|
|
138
152
|
inputs = bound_args.arguments["inputs"]
|
|
139
153
|
warn = bound_args.arguments["warn"]
|
|
154
|
+
output_fields = bound_args.arguments.get("output_fields")
|
|
140
155
|
|
|
141
156
|
num_timesteps = len(emissions)
|
|
142
157
|
full_params, inputs = preprocess_params_and_inputs(
|
|
143
158
|
params, num_timesteps, inputs
|
|
144
159
|
)
|
|
145
160
|
|
|
146
|
-
|
|
147
|
-
full_params,
|
|
148
|
-
emissions,
|
|
149
|
-
t_emissions,
|
|
161
|
+
call_kwargs = dict(
|
|
162
|
+
params=full_params,
|
|
163
|
+
emissions=emissions,
|
|
164
|
+
t_emissions=t_emissions,
|
|
150
165
|
filter_hyperparams=filter_hyperparams,
|
|
151
166
|
inputs=inputs,
|
|
152
167
|
warn=warn,
|
|
153
168
|
)
|
|
169
|
+
if output_fields is not None:
|
|
170
|
+
call_kwargs["output_fields"] = output_fields
|
|
171
|
+
|
|
172
|
+
return f(
|
|
173
|
+
**call_kwargs,
|
|
174
|
+
)
|
|
154
175
|
|
|
155
176
|
return wrapper
|
|
156
177
|
|
|
@@ -334,6 +355,14 @@ def _condition_on(m, P, H, D, d, R, u, y, warn: bool = True):
|
|
|
334
355
|
return mu_cond, Sigma_cond
|
|
335
356
|
|
|
336
357
|
|
|
358
|
+
def _emission_predicted_moments(m, P, H, D, d, u, warn: bool = True):
    """Return the predicted emission mean and covariance for a Gaussian state.

    Args:
        m: mean of the state.
        P: covariance of the state.
        H: emission matrix multiplying the state.
        D: emission matrix multiplying the input.
        d: emission bias term.
        u: input vector.
        warn: forwarded unchanged to `psd`.

    Returns:
        A pair `(mean, cov)` with `mean = H @ m + D @ u + d` and
        `cov = psd(H @ P @ H.T, warn=warn)`. No emission-noise term is
        added here; this function takes no noise covariance argument.
    """
    # Same left-to-right summation order as the mean formula above.
    emission_mean = H @ m + D @ u + d
    # `psd` is defined elsewhere in this module; presumably it enforces
    # positive semi-definiteness of the projected covariance.
    projected_cov = H @ P @ H.T
    emission_cov = psd(projected_cov, warn=warn)
    return emission_mean, emission_cov
|
|
364
|
+
|
|
365
|
+
|
|
337
366
|
# CD-LGSSM filtering implementation: Kalman filter
|
|
338
367
|
@preprocess_args
|
|
339
368
|
def cdlgssm_filter(
|
|
@@ -342,6 +371,12 @@ def cdlgssm_filter(
|
|
|
342
371
|
t_emissions: Optional[Float[Array, "num_timesteps 1"]] = None,
|
|
343
372
|
filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(),
|
|
344
373
|
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None,
|
|
374
|
+
output_fields: Optional[List[str]] = [
|
|
375
|
+
"filtered_means",
|
|
376
|
+
"filtered_covariances",
|
|
377
|
+
"predicted_means",
|
|
378
|
+
"predicted_covariances",
|
|
379
|
+
],
|
|
345
380
|
warn: bool = True,
|
|
346
381
|
) -> PosteriorGSSMFiltered:
|
|
347
382
|
r"""Run a Continuous Discrete Kalman filter
|
|
@@ -353,6 +388,17 @@ def cdlgssm_filter(
|
|
|
353
388
|
t_emissions: continuous-time specific time instants of observations: if not None, it is an array
|
|
354
389
|
filter_hyperparams: hyperparameters for the filter.
|
|
355
390
|
inputs: optional array of inputs.
|
|
391
|
+
output_fields: list of top-level posterior fields to return.
|
|
392
|
+
Options:
|
|
393
|
+
`"filtered_means"` (default)
|
|
394
|
+
`"filtered_covariances"` (default)
|
|
395
|
+
`"predicted_means"` (default)
|
|
396
|
+
`"predicted_covariances"` (default)
|
|
397
|
+
`"marginal_loglik"`
|
|
398
|
+
`"y_pred_mean"`
|
|
399
|
+
`"y_pred_cov"`
|
|
400
|
+
`"y_obs_pred_mean"`
|
|
401
|
+
`"y_obs_pred_cov"`
|
|
356
402
|
warn: whether to issue warnings during filtering (e.g., PSD issues).
|
|
357
403
|
|
|
358
404
|
Returns:
|
|
@@ -362,6 +408,10 @@ def cdlgssm_filter(
|
|
|
362
408
|
if filter_hyperparams is None:
|
|
363
409
|
filter_hyperparams = KFHyperParams()
|
|
364
410
|
|
|
411
|
+
validate_filtered_posterior_output_fields(
|
|
412
|
+
"cdlgssm_filter", output_fields, CDLGSSM_FILTER_OUTPUT_FIELDS
|
|
413
|
+
)
|
|
414
|
+
|
|
365
415
|
# Figure out timestamps, as vectors to scan over
|
|
366
416
|
# t_emissions is of shape num_timesteps \times 1
|
|
367
417
|
# t0 and t1 are num_timesteps \times 0
|
|
@@ -405,8 +455,14 @@ def cdlgssm_filter(
|
|
|
405
455
|
u = inputs[t0_idx]
|
|
406
456
|
y = emissions[t0_idx]
|
|
407
457
|
|
|
458
|
+
y_pred_mean, y_pred_cov = _emission_predicted_moments(
|
|
459
|
+
pred_mean, pred_cov, H, D, d, u, warn=warn
|
|
460
|
+
)
|
|
461
|
+
y_obs_pred_mean = y_pred_mean
|
|
462
|
+
y_obs_pred_cov = psd(y_pred_cov + R, warn=warn)
|
|
463
|
+
|
|
408
464
|
# Update the log likelihood
|
|
409
|
-
ll += MVN(
|
|
465
|
+
ll += MVN(y_obs_pred_mean, y_obs_pred_cov).log_prob(y)
|
|
410
466
|
|
|
411
467
|
# Condition on this emission
|
|
412
468
|
filtered_mean, filtered_cov = _condition_on(
|
|
@@ -427,30 +483,30 @@ def cdlgssm_filter(
|
|
|
427
483
|
filtered_mean, filtered_cov, F, C, B, b, Q, u, warn=warn
|
|
428
484
|
)
|
|
429
485
|
|
|
486
|
+
outputs = {
|
|
487
|
+
"filtered_means": filtered_mean,
|
|
488
|
+
"filtered_covariances": filtered_cov,
|
|
489
|
+
"predicted_means": pred_mean,
|
|
490
|
+
"predicted_covariances": pred_cov,
|
|
491
|
+
"y_pred_mean": y_pred_mean,
|
|
492
|
+
"y_pred_cov": y_pred_cov,
|
|
493
|
+
"y_obs_pred_mean": y_obs_pred_mean,
|
|
494
|
+
"y_obs_pred_cov": y_obs_pred_cov,
|
|
495
|
+
}
|
|
496
|
+
outputs = {key: val for key, val in outputs.items() if key in output_fields}
|
|
497
|
+
|
|
430
498
|
# Return the carry and outputs
|
|
431
|
-
return (ll, pred_mean, pred_cov),
|
|
432
|
-
filtered_mean,
|
|
433
|
-
filtered_cov,
|
|
434
|
-
pred_mean,
|
|
435
|
-
pred_cov,
|
|
436
|
-
)
|
|
499
|
+
return (ll, pred_mean, pred_cov), outputs
|
|
437
500
|
|
|
438
501
|
# The Kalman filter
|
|
439
502
|
# Initial carry
|
|
440
503
|
carry = (0.0, params.initial.mean, params.initial.cov)
|
|
441
504
|
# Scan over all time steps
|
|
442
|
-
(ll, _, _), (
|
|
443
|
-
|
|
444
|
-
)
|
|
505
|
+
(ll, _, _), outputs = lax.scan(_step, carry, (t0, t1, t0_idx))
|
|
506
|
+
outputs = {"marginal_loglik": ll, **outputs}
|
|
445
507
|
|
|
446
508
|
# Return the posterior object
|
|
447
|
-
return PosteriorGSSMFiltered(
|
|
448
|
-
marginal_loglik=ll,
|
|
449
|
-
filtered_means=filtered_means,
|
|
450
|
-
filtered_covariances=filtered_covs,
|
|
451
|
-
predicted_means=pred_means,
|
|
452
|
-
predicted_covariances=pred_covs,
|
|
453
|
-
)
|
|
509
|
+
return PosteriorGSSMFiltered(**outputs)
|
|
454
510
|
|
|
455
511
|
|
|
456
512
|
# Kalman smoothing equations, in continuous-time
|
|
@@ -1110,12 +1166,10 @@ def cdlgssm_forecast(
|
|
|
1110
1166
|
inputs: optional array of inputs, of shape (1 + num_timesteps) \times input_dim
|
|
1111
1167
|
- The extra input is needed for the initial emission, i.e., it should be at time t_init
|
|
1112
1168
|
output_fields: list of fields to return in posterior object.
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
If we forecast paths, based on solving the SDE
|
|
1118
|
-
"forecasted_state_path".
|
|
1169
|
+
Options:
|
|
1170
|
+
`"forecasted_state_means"` (default for Gaussian forecasts)
|
|
1171
|
+
`"forecasted_state_covariances"` (default for Gaussian forecasts)
|
|
1172
|
+
`"forecasted_state_path"` (for point-initialized path forecasts)
|
|
1119
1173
|
key: random key (e.g., for Ensemble Kalman).
|
|
1120
1174
|
diffeqsolve_settings: settings for the SDE solver
|
|
1121
1175
|
warn: whether to issue warnings during filtering (e.g., PSD issues).
|
|
@@ -4,7 +4,7 @@ import jax.random as jr
|
|
|
4
4
|
from jaxtyping import Array, Float
|
|
5
5
|
|
|
6
6
|
# Type annotations
|
|
7
|
-
from typing import Any, Optional, Tuple, Union
|
|
7
|
+
from typing import Any, List, Optional, Tuple, Union
|
|
8
8
|
from typing_extensions import Protocol
|
|
9
9
|
|
|
10
10
|
# Distributions, compatible with JAX, from TensorFlow Probability
|
|
@@ -510,6 +510,12 @@ class ContDiscreteLinearGaussianSSM(SSM):
|
|
|
510
510
|
t_emissions: Optional[Float[Array, "ntime 1"]] = None,
|
|
511
511
|
filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(),
|
|
512
512
|
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
|
|
513
|
+
output_fields: Optional[List[str]] = [
|
|
514
|
+
"filtered_means",
|
|
515
|
+
"filtered_covariances",
|
|
516
|
+
"predicted_means",
|
|
517
|
+
"predicted_covariances",
|
|
518
|
+
],
|
|
513
519
|
warn: bool = True,
|
|
514
520
|
) -> PosteriorGSSMFiltered:
|
|
515
521
|
r"""Run the CD-Kalman filter to compute the filtered posterior distribution over states.
|
|
@@ -519,6 +525,20 @@ class ContDiscreteLinearGaussianSSM(SSM):
|
|
|
519
525
|
t_emissions: continuous-time specific time instants of observations: if not None, it is an array
|
|
520
526
|
filter_hyperparams: hyperparameters for the Kalman filter.
|
|
521
527
|
inputs: optional sequence of inputs.
|
|
528
|
+
output_fields: Which top-level posterior fields to return from the
|
|
529
|
+
filter. Options:
|
|
530
|
+
`"filtered_means"` (default)
|
|
531
|
+
`"filtered_covariances"` (default)
|
|
532
|
+
`"predicted_means"` (default)
|
|
533
|
+
`"predicted_covariances"` (default)
|
|
534
|
+
`"marginal_loglik"`
|
|
535
|
+
`"y_pred_mean"`
|
|
536
|
+
`"y_pred_cov"`
|
|
537
|
+
`"y_obs_pred_mean"`
|
|
538
|
+
`"y_obs_pred_cov"`
|
|
539
|
+
Predictive emission fields are available directly as
|
|
540
|
+
top-level filtered posterior outputs. Unsupported fields raise
|
|
541
|
+
a `ValueError`.
|
|
522
542
|
warn: whether to warn about numerical issues.
|
|
523
543
|
Returns:
|
|
524
544
|
filtered posterior distribution over states.
|
|
@@ -526,7 +546,13 @@ class ContDiscreteLinearGaussianSSM(SSM):
|
|
|
526
546
|
|
|
527
547
|
# Directly run the CD-Kalman filter
|
|
528
548
|
return cdlgssm_filter(
|
|
529
|
-
params,
|
|
549
|
+
params,
|
|
550
|
+
emissions,
|
|
551
|
+
t_emissions,
|
|
552
|
+
filter_hyperparams,
|
|
553
|
+
inputs,
|
|
554
|
+
output_fields=output_fields,
|
|
555
|
+
warn=warn,
|
|
530
556
|
)
|
|
531
557
|
|
|
532
558
|
# High-level, user-friendly interface combining filtering and forecasting steps
|