cd-dynamax 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/PKG-INFO +46 -27
  2. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/README.md +44 -26
  3. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +76 -3
  4. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +6 -2
  5. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +7 -0
  6. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +81 -27
  7. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +28 -2
  8. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +72 -20
  9. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +102 -29
  10. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +87 -36
  11. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +60 -19
  12. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +15 -15
  13. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +22 -4
  14. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/demo_utils.py +7 -7
  15. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/PKG-INFO +46 -27
  16. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/SOURCES.txt +3 -0
  17. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/requires.txt +3 -0
  18. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/pyproject.toml +3 -2
  19. cd_dynamax-0.3.4/tests/test_linear_gaussian_filter_posterior_extras.py +107 -0
  20. cd_dynamax-0.3.4/tests/test_nonlinear_gaussian_filter_posterior_extras.py +237 -0
  21. cd_dynamax-0.3.4/tests/test_state_dependent_diffusion.py +93 -0
  22. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/LICENSE +0 -0
  23. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/__init__.py +0 -0
  24. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/__init__.py +0 -0
  25. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/_version.py +0 -0
  26. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +0 -0
  27. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +0 -0
  28. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +0 -0
  29. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +0 -0
  30. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +0 -0
  31. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/__init__.py +0 -0
  32. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/inference.py +0 -0
  33. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/inference_test.py +0 -0
  34. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
  35. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +0 -0
  36. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +0 -0
  37. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +0 -0
  38. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +0 -0
  39. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +0 -0
  40. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +0 -0
  41. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +0 -0
  42. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +0 -0
  43. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/initial.py +0 -0
  44. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +0 -0
  45. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +0 -0
  46. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +0 -0
  47. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +0 -0
  48. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +0 -0
  49. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +0 -0
  50. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +0 -0
  51. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +0 -0
  52. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +0 -0
  53. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +0 -0
  54. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +0 -0
  55. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +0 -0
  56. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/models.py +0 -0
  57. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +0 -0
  58. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +0 -0
  59. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +0 -0
  60. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +0 -0
  61. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +0 -0
  62. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +0 -0
  63. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +0 -0
  64. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +0 -0
  65. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +0 -0
  66. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/parameters.py +0 -0
  67. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/parameters_test.py +0 -0
  68. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/__init__.py +0 -0
  69. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/inference.py +0 -0
  70. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/inference_test.py +0 -0
  71. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +0 -0
  72. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/slds/models.py +0 -0
  73. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/ssm.py +0 -0
  74. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/types.py +0 -0
  75. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/__init__.py +0 -0
  76. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/bijectors.py +0 -0
  77. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/distributions.py +0 -0
  78. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/distributions_test.py +0 -0
  79. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/optimize.py +0 -0
  80. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/plotting.py +0 -0
  81. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/utils.py +0 -0
  82. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/utils/utils_test.py +0 -0
  83. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/dynamax/warnings.py +0 -0
  84. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/__init__.py +0 -0
  85. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +0 -0
  86. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +0 -0
  87. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +0 -0
  88. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +0 -0
  89. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +0 -0
  90. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +0 -0
  91. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/ssm_temissions.py +0 -0
  92. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/__init__.py +0 -0
  93. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/data_driven_models.py +0 -0
  94. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/data_generator.py +0 -0
  95. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/debug_utils.py +0 -0
  96. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/diffrax_utils.py +0 -0
  97. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/evaluation_utils.py +0 -0
  98. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/experiment_utils.py +0 -0
  99. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/likelihood_eval_utils.py +0 -0
  100. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/optimize_utils.py +0 -0
  101. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/physics_based_models.py +0 -0
  102. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/plotting_chaos_utils.py +0 -0
  103. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/plotting_utils.py +0 -0
  104. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/prior_utils.py +0 -0
  105. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/simulation_utils.py +0 -0
  106. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax/src/utils/test_utils.py +0 -0
  107. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/dependency_links.txt +0 -0
  108. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/cd_dynamax.egg-info/top_level.txt +0 -0
  109. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/setup.cfg +0 -0
  110. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_cdlgssm_dlgssm_match.py +0 -0
  111. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_cdnonlinear_cdlinear_match.py +0 -0
  112. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_filter_forecast_emissions.py +0 -0
  113. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_imports.py +0 -0
  114. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_models.py +0 -0
  115. {cd_dynamax-0.3.2 → cd_dynamax-0.3.4}/tests/test_utils_imports.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cd_dynamax
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Continuous-discrete dynamical systems with JAX and related libraries.
5
5
  Author: Matthew Levine, Iñigo Urteaga
6
6
  Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
32
32
  Requires-Dist: fastprogress>=1.0.0
33
33
  Requires-Dist: graphviz
34
34
  Requires-Dist: ipykernel
35
+ Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
35
36
  Provides-Extra: dev
36
37
  Requires-Dist: pytest>=8.0; extra == "dev"
37
38
  Requires-Dist: ruff; extra == "dev"
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
69
70
 
70
71
  where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
71
72
 
72
- We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ L,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
74
 
74
75
  Note:
75
76
 
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
90
91
 
91
92
  All of these problems are deeply interconnected.
92
93
 
93
- - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$.
94
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
94
95
 
95
96
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
96
97
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
113
114
 
114
115
  - The codebase is organized into several key directories:
115
116
  ```
116
- cd_dynamax/
117
- ├── src/ # Source code for cd-dynamax library
118
- │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
119
- │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
120
- │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
121
- └── utils/ # Utility functions and example models
122
- ├── dynamax/ # Original dynamax library (as a submodule)
123
- demos/ # Python demos showcasing cd-dynamax functionality
124
- ├── python/scripts/ # Python scripts for running demos
125
- ├── python/notebooks/ # Jupyter notebooks for interactive demos
126
- ├── python/configs/ # Configuration files for demos
127
- tests/ # Tests for cd-dynamax functionality
117
+ .
118
+ ├── cd_dynamax/ # Source code for cd-dynamax library
119
+ │ ├── src/ # Core source code
120
+ │ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
121
+ │ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
122
+ │ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
123
+ │ │ └── utils/ # Utility functions and example models
124
+ │ └── dynamax/ # Original dynamax library (as a submodule)
125
+ ├── demos/ # Python demos showcasing cd-dynamax functionality
126
+ │ ├── python/scripts/ # Python scripts for running demos
127
+ │ ├── python/notebooks/ # Jupyter notebooks for interactive demos
128
+ │ └── python/configs/ # Configuration files for demos
129
+ └── tests/ # Tests for cd-dynamax functionality
128
130
  ```
129
131
 
130
132
  ## [Demos](./demos/python)
@@ -155,7 +157,7 @@ make test
155
157
  ```
156
158
 
157
159
  - For linting, we use `ruff`:
158
- ```bashbash
160
+ ```bash
159
161
  make lint
160
162
  ```
161
163
 
@@ -171,40 +173,57 @@ make build_docs
171
173
 
172
174
  # Installation
173
175
 
174
- We support installation via **Conda** (recommended) or via a standard Python virtual environment.
176
+ Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
175
177
 
176
178
  ---
177
179
 
178
- ### Option 1: Conda (recommended)
180
+ ### Option 1: Install from PyPI (recommended)
179
181
 
180
182
  ```bash
181
- # Create and activate a new environment with Python 3.11
182
- conda create -n cd_dynamax_joss python=3.11
183
- conda activate cd_dynamax_joss
183
+ # Create and activate a virtual environment
184
+ python -m venv .venv # or `uv venv`
185
+ source .venv/bin/activate # on macOS/Linux
186
+ .venv\Scripts\activate # on Windows
184
187
 
185
- # Install your package in editable mode (so local changes are picked up)
186
- pip install -e .[dev]
188
+ # Upgrade pip
189
+ pip install --upgrade pip
190
+
191
+ # Install latest release from PyPI
192
+ pip install cd-dynamax
187
193
  ```
188
194
 
189
- This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
195
+ `cd-dynamax` is currently **not available on Conda Forge**.
190
196
 
191
197
  ---
192
198
 
193
- ### Option 2: Python venv + pip
199
+ ### Option 2: Install from source (editable)
194
200
 
195
201
  ```bash
196
202
  # Create and activate a virtual environment
197
- python -m venv .venv
203
+ python -m venv .venv # or `uv venv`
198
204
  source .venv/bin/activate # on macOS/Linux
199
205
  .venv\Scripts\activate # on Windows
200
206
 
201
207
  # Upgrade pip
202
208
  pip install --upgrade pip
203
209
 
204
- # Install in editable mode
210
+ # Install in editable mode for local development
205
211
  pip install -e .[dev]
206
212
  ```
207
213
 
214
+ ---
215
+
216
+ ### Option 3: Conda environment + pip install
217
+
218
+ ```bash
219
+ # Create and activate a Conda environment with Python 3.11
220
+ conda create -n cd_dynamax python=3.11
221
+ conda activate cd_dynamax
222
+
223
+ # Install latest release from PyPI
224
+ pip install cd-dynamax
225
+ ```
226
+
208
227
  #### GPU support
209
228
  If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
210
229
 
@@ -25,7 +25,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
25
25
 
26
26
  where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
27
27
 
28
- We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ L,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
28
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
29
29
 
30
30
  Note:
31
31
 
@@ -46,7 +46,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
46
46
 
47
47
  All of these problems are deeply interconnected.
48
48
 
49
- - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$.
49
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
50
50
 
51
51
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
52
52
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
@@ -69,18 +69,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
69
69
 
70
70
  - The codebase is organized into several key directories:
71
71
  ```
72
- cd_dynamax/
73
- ├── src/ # Source code for cd-dynamax library
74
- │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
75
- │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
76
- │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
77
- └── utils/ # Utility functions and example models
78
- ├── dynamax/ # Original dynamax library (as a submodule)
79
- demos/ # Python demos showcasing cd-dynamax functionality
80
- ├── python/scripts/ # Python scripts for running demos
81
- ├── python/notebooks/ # Jupyter notebooks for interactive demos
82
- ├── python/configs/ # Configuration files for demos
83
- tests/ # Tests for cd-dynamax functionality
72
+ .
73
+ ├── cd_dynamax/ # Source code for cd-dynamax library
74
+ │ ├── src/ # Core source code
75
+ │ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
76
+ │ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
77
+ │ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
78
+ │ │ └── utils/ # Utility functions and example models
79
+ │ └── dynamax/ # Original dynamax library (as a submodule)
80
+ ├── demos/ # Python demos showcasing cd-dynamax functionality
81
+ │ ├── python/scripts/ # Python scripts for running demos
82
+ │ ├── python/notebooks/ # Jupyter notebooks for interactive demos
83
+ │ └── python/configs/ # Configuration files for demos
84
+ └── tests/ # Tests for cd-dynamax functionality
84
85
  ```
85
86
 
86
87
  ## [Demos](./demos/python)
@@ -111,7 +112,7 @@ make test
111
112
  ```
112
113
 
113
114
  - For linting, we use `ruff`:
114
- ```bashbash
115
+ ```bash
115
116
  make lint
116
117
  ```
117
118
 
@@ -127,40 +128,57 @@ make build_docs
127
128
 
128
129
  # Installation
129
130
 
130
- We support installation via **Conda** (recommended) or via a standard Python virtual environment.
131
+ Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
131
132
 
132
133
  ---
133
134
 
134
- ### Option 1: Conda (recommended)
135
+ ### Option 1: Install from PyPI (recommended)
135
136
 
136
137
  ```bash
137
- # Create and activate a new environment with Python 3.11
138
- conda create -n cd_dynamax_joss python=3.11
139
- conda activate cd_dynamax_joss
138
+ # Create and activate a virtual environment
139
+ python -m venv .venv # or `uv venv`
140
+ source .venv/bin/activate # on macOS/Linux
141
+ .venv\Scripts\activate # on Windows
140
142
 
141
- # Install your package in editable mode (so local changes are picked up)
142
- pip install -e .[dev]
143
+ # Upgrade pip
144
+ pip install --upgrade pip
145
+
146
+ # Install latest release from PyPI
147
+ pip install cd-dynamax
143
148
  ```
144
149
 
145
- This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
150
+ `cd-dynamax` is currently **not available on Conda Forge**.
146
151
 
147
152
  ---
148
153
 
149
- ### Option 2: Python venv + pip
154
+ ### Option 2: Install from source (editable)
150
155
 
151
156
  ```bash
152
157
  # Create and activate a virtual environment
153
- python -m venv .venv
158
+ python -m venv .venv # or `uv venv`
154
159
  source .venv/bin/activate # on macOS/Linux
155
160
  .venv\Scripts\activate # on Windows
156
161
 
157
162
  # Upgrade pip
158
163
  pip install --upgrade pip
159
164
 
160
- # Install in editable mode
165
+ # Install in editable mode for local development
161
166
  pip install -e .[dev]
162
167
  ```
163
168
 
169
+ ---
170
+
171
+ ### Option 3: Conda environment + pip install
172
+
173
+ ```bash
174
+ # Create and activate a Conda environment with Python 3.11
175
+ conda create -n cd_dynamax python=3.11
176
+ conda activate cd_dynamax
177
+
178
+ # Install latest release from PyPI
179
+ pip install cd-dynamax
180
+ ```
181
+
164
182
  #### GPU support
165
183
  If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
166
184
 
@@ -11,7 +11,7 @@ from tensorflow_probability.substrates.jax.distributions import (
11
11
 
12
12
  from jax.tree_util import tree_map
13
13
  from jaxtyping import Array, Float
14
- from typing import NamedTuple, Optional, Union, Tuple
14
+ from typing import Collection, NamedTuple, Optional, Union, Tuple
15
15
  from ..utils.utils import psd_solve, symmetrize
16
16
  from ..parameters import ParameterProperties
17
17
  from ..types import PRNGKey, Scalar
@@ -112,9 +112,27 @@ class ParamsLGSSM(NamedTuple):
112
112
  class PosteriorGSSMFiltered(NamedTuple):
113
113
  r"""Marginals of the Gaussian filtering posterior.
114
114
 
115
+ Time indexing convention:
116
+ `filtered_*[t]` corresponds to $x_{t \mid t}$,
117
+ `predicted_*[t]` corresponds to $x_{t+1 \mid t}$,
118
+ and predictive observation quantities such as `y_pred_*[t]` and
119
+ `y_obs_pred_*[t]` correspond to the observation-time prior
120
+ $p(y_t \mid y_{1:t-1})$, i.e. they are computed from $x_{t \mid t-1}$.
121
+
115
122
  :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
116
- :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
117
- :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
123
+ :param filtered_means: array of filtered means $\mathbb{E}[x_t \mid y_{1:t}, u_{1:t}]$
124
+ :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[x_t \mid y_{1:t}, u_{1:t}]$
125
+ :param predicted_means: optional array of one-step-ahead state means $\mathbb{E}[x_{t+1} \mid y_{1:t}, u_{1:t}]$
126
+ :param predicted_covariances: optional array of one-step-ahead state covariances $\mathrm{Cov}[x_{t+1} \mid y_{1:t}, u_{1:t}]$
127
+ :param filtered_ensembles: When filtering with EnKF, optional array of filtered state ensembles approximating $p(x_t \mid y_{1:t}, u_{1:t})$
128
+ :param predicted_ensembles: When filtering with EnKF, optional array of one-step-ahead predicted state ensembles approximating $p(x_{t+1} \mid y_{1:t}, u_{1:t})$
129
+ :param y_pred_mean: optional array of predictive emission means for $p(h(x_t) \mid y_{1:t-1}, u_{1:t})$, i.e., before adding observation noise to h(x_t)
130
+ :param y_pred_cov: optional array of predictive emission covariances for $p(h(x_t) \mid y_{1:t-1}, u_{1:t})$, i.e., before adding observation noise to h(x_t)
131
+ :param y_obs_pred_mean: optional array of predictive observation means for $p(y_t \mid y_{1:t-1}, u_{1:t})$
132
+ :param y_obs_pred_cov: optional array of predictive observation covariances for $p(y_t \mid y_{1:t-1}, u_{1:t})$, including observation noise
133
+ :param y_ens_pred: When filtering with EnKF, optional array of predictive emission ensembles for observation time $t$, obtained from the predicted state ensemble for $x_{t \mid t-1}$
134
+ :param y_obs_ens_pred: When filtering with EnKF, optional array of predictive observation ensembles for observation time $t$, obtained by adding sampled observation noise to `y_ens_pred`
135
+ :param posterior_extras: When filtering with EnKF, optional dictionary of filter-specific per-step diagnostics that do not fit the standard filtered posterior schema
118
136
 
119
137
  """
120
138
  # Default attributes
@@ -123,9 +141,64 @@ class PosteriorGSSMFiltered(NamedTuple):
123
141
  filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
124
142
  predicted_means: Optional[Float[Array, "ntime state_dim"]] = None
125
143
  predicted_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
144
+ filtered_ensembles: Optional[Float[Array, "ntime ensemble_size state_dim"]] = None
145
+ predicted_ensembles: Optional[Float[Array, "ntime ensemble_size state_dim"]] = None
146
+ y_pred_mean: Optional[Float[Array, "ntime emission_dim"]] = None
147
+ y_pred_cov: Optional[Float[Array, "ntime emission_dim emission_dim"]] = None
148
+ y_obs_pred_mean: Optional[Float[Array, "ntime emission_dim"]] = None
149
+ y_obs_pred_cov: Optional[Float[Array, "ntime emission_dim emission_dim"]] = None
150
+ y_ens_pred: Optional[Float[Array, "ntime ensemble_size emission_dim"]] = None
151
+ y_obs_ens_pred: Optional[Float[Array, "ntime ensemble_size emission_dim"]] = None
126
152
  # Additional extras
127
153
  posterior_extras: Optional[dict] = None
128
154
 
155
+
156
+ FILTERED_POSTERIOR_FIELD_NAMES = (
157
+ "marginal_loglik",
158
+ "filtered_means",
159
+ "filtered_covariances",
160
+ "predicted_means",
161
+ "predicted_covariances",
162
+ "filtered_ensembles",
163
+ "predicted_ensembles",
164
+ "y_pred_mean",
165
+ "y_pred_cov",
166
+ "y_obs_pred_mean",
167
+ "y_obs_pred_cov",
168
+ "y_ens_pred",
169
+ "y_obs_ens_pred",
170
+ "posterior_extras",
171
+ )
172
+
173
+
174
+ def validate_filtered_posterior_output_fields(
175
+ filter_name: str,
176
+ output_fields: Optional[Collection[str]],
177
+ supported_fields: Collection[str],
178
+ ):
179
+ """Validate requested top-level filtered posterior fields."""
180
+
181
+ if output_fields is None:
182
+ return
183
+
184
+ valid_fields = set(FILTERED_POSTERIOR_FIELD_NAMES)
185
+ supported_fields = set(supported_fields)
186
+ requested_fields = list(output_fields)
187
+
188
+ unknown_fields = sorted(set(requested_fields) - valid_fields)
189
+ if unknown_fields:
190
+ raise ValueError(
191
+ f"Unknown output_fields for {filter_name}: {unknown_fields}. "
192
+ f"Valid top-level filtered posterior fields are: {sorted(valid_fields)}."
193
+ )
194
+
195
+ unsupported_fields = sorted(set(requested_fields) - supported_fields)
196
+ if unsupported_fields:
197
+ raise ValueError(
198
+ f"output_fields {unsupported_fields} are not available for {filter_name}. "
199
+ f"Available fields are: {sorted(supported_fields)}."
200
+ )
201
+
129
202
  class PosteriorGSSMSmoothed(NamedTuple):
130
203
  r"""Marginals of the Gaussian filtering and smoothing posterior.
131
204
 
@@ -101,8 +101,12 @@ def extended_kalman_filter(
101
101
  num_iter: number of linearizations around posterior for update step (default 1).
102
102
  inputs: optional array of inputs.
103
103
  output_fields: list of fields to return in posterior object.
104
- These can take the values "filtered_means", "filtered_covariances",
105
- "predicted_means", "predicted_covariances", and "marginal_loglik".
104
+ Options:
105
+ `"filtered_means"` (default)
106
+ `"filtered_covariances"` (default)
107
+ `"predicted_means"` (default)
108
+ `"predicted_covariances"` (default)
109
+ `"marginal_loglik"`
106
110
 
107
111
  Returns:
108
112
  post: posterior object.
@@ -151,6 +151,13 @@ def unscented_kalman_filter(
151
151
  emissions: array of observations.
152
152
  hyperparams: hyper-parameters.
153
153
  inputs: optional array of inputs.
154
+ output_fields: list of fields to return in posterior object.
155
+ Options:
156
+ `"filtered_means"` (default)
157
+ `"filtered_covariances"` (default)
158
+ `"predicted_means"` (default)
159
+ `"predicted_covariances"` (default)
160
+ `"marginal_loglik"`
154
161
 
155
162
  Returns:
156
163
  filtered_posterior: posterior object.
@@ -27,6 +27,7 @@ from cd_dynamax.dynamax.utils.utils import psd_solve
27
27
  from cd_dynamax.dynamax.linear_gaussian_ssm.inference import (
28
28
  PosteriorGSSMFiltered,
29
29
  PosteriorGSSMSmoothed,
30
+ validate_filtered_posterior_output_fields,
30
31
  )
31
32
 
32
33
  # Initial and emission parameter classes are equivalent
@@ -54,6 +55,19 @@ tfb = tfp.bijectors
54
55
  DEBUG = False
55
56
 
56
57
 
58
+ CDLGSSM_FILTER_OUTPUT_FIELDS = (
59
+ "marginal_loglik",
60
+ "filtered_means",
61
+ "filtered_covariances",
62
+ "predicted_means",
63
+ "predicted_covariances",
64
+ "y_pred_mean",
65
+ "y_pred_cov",
66
+ "y_obs_pred_mean",
67
+ "y_obs_pred_cov",
68
+ )
69
+
70
+
57
71
  #### Helper functions
58
72
  # Helper functions --- modified from dynamax
59
73
  def _get_params(x, dim, t):
@@ -137,20 +151,27 @@ def preprocess_args(f):
137
151
  filter_hyperparams = bound_args.arguments["filter_hyperparams"]
138
152
  inputs = bound_args.arguments["inputs"]
139
153
  warn = bound_args.arguments["warn"]
154
+ output_fields = bound_args.arguments.get("output_fields")
140
155
 
141
156
  num_timesteps = len(emissions)
142
157
  full_params, inputs = preprocess_params_and_inputs(
143
158
  params, num_timesteps, inputs
144
159
  )
145
160
 
146
- return f(
147
- full_params,
148
- emissions,
149
- t_emissions,
161
+ call_kwargs = dict(
162
+ params=full_params,
163
+ emissions=emissions,
164
+ t_emissions=t_emissions,
150
165
  filter_hyperparams=filter_hyperparams,
151
166
  inputs=inputs,
152
167
  warn=warn,
153
168
  )
169
+ if output_fields is not None:
170
+ call_kwargs["output_fields"] = output_fields
171
+
172
+ return f(
173
+ **call_kwargs,
174
+ )
154
175
 
155
176
  return wrapper
156
177
 
@@ -334,6 +355,14 @@ def _condition_on(m, P, H, D, d, R, u, y, warn: bool = True):
334
355
  return mu_cond, Sigma_cond
335
356
 
336
357
 
358
+ def _emission_predicted_moments(m, P, H, D, d, u, warn: bool = True):
359
+ """Compute the predictive emission moments under the linear Gaussian model."""
360
+
361
+ y_pred_mean = H @ m + D @ u + d
362
+ y_pred_cov = psd(H @ P @ H.T, warn=warn)
363
+ return y_pred_mean, y_pred_cov
364
+
365
+
337
366
  # CD-LGSSM filtering implementation: Kalman filter
338
367
  @preprocess_args
339
368
  def cdlgssm_filter(
@@ -342,6 +371,12 @@ def cdlgssm_filter(
342
371
  t_emissions: Optional[Float[Array, "num_timesteps 1"]] = None,
343
372
  filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(),
344
373
  inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None,
374
+ output_fields: Optional[List[str]] = [
375
+ "filtered_means",
376
+ "filtered_covariances",
377
+ "predicted_means",
378
+ "predicted_covariances",
379
+ ],
345
380
  warn: bool = True,
346
381
  ) -> PosteriorGSSMFiltered:
347
382
  r"""Run a Continuous Discrete Kalman filter
@@ -353,6 +388,17 @@ def cdlgssm_filter(
353
388
  t_emissions: continuous-time specific time instants of observations: if not None, it is an array
354
389
  filter_hyperparams: hyperparameters for the filter.
355
390
  inputs: optional array of inputs.
391
+ output_fields: list of top-level posterior fields to return.
392
+ Options:
393
+ `"filtered_means"` (default)
394
+ `"filtered_covariances"` (default)
395
+ `"predicted_means"` (default)
396
+ `"predicted_covariances"` (default)
397
+ `"marginal_loglik"`
398
+ `"y_pred_mean"`
399
+ `"y_pred_cov"`
400
+ `"y_obs_pred_mean"`
401
+ `"y_obs_pred_cov"`
356
402
  warn: whether to issue warnings during filtering (e.g., PSD issues).
357
403
 
358
404
  Returns:
@@ -362,6 +408,10 @@ def cdlgssm_filter(
362
408
  if filter_hyperparams is None:
363
409
  filter_hyperparams = KFHyperParams()
364
410
 
411
+ validate_filtered_posterior_output_fields(
412
+ "cdlgssm_filter", output_fields, CDLGSSM_FILTER_OUTPUT_FIELDS
413
+ )
414
+
365
415
  # Figure out timestamps, as vectors to scan over
366
416
  # t_emissions is of shape num_timesteps \times 1
367
417
  # t0 and t1 are num_timesteps \times 0
@@ -405,8 +455,14 @@ def cdlgssm_filter(
405
455
  u = inputs[t0_idx]
406
456
  y = emissions[t0_idx]
407
457
 
458
+ y_pred_mean, y_pred_cov = _emission_predicted_moments(
459
+ pred_mean, pred_cov, H, D, d, u, warn=warn
460
+ )
461
+ y_obs_pred_mean = y_pred_mean
462
+ y_obs_pred_cov = psd(y_pred_cov + R, warn=warn)
463
+
408
464
  # Update the log likelihood
409
- ll += MVN(H @ pred_mean + D @ u + d, H @ pred_cov @ H.T + R).log_prob(y)
465
+ ll += MVN(y_obs_pred_mean, y_obs_pred_cov).log_prob(y)
410
466
 
411
467
  # Condition on this emission
412
468
  filtered_mean, filtered_cov = _condition_on(
@@ -427,30 +483,30 @@ def cdlgssm_filter(
427
483
  filtered_mean, filtered_cov, F, C, B, b, Q, u, warn=warn
428
484
  )
429
485
 
486
+ outputs = {
487
+ "filtered_means": filtered_mean,
488
+ "filtered_covariances": filtered_cov,
489
+ "predicted_means": pred_mean,
490
+ "predicted_covariances": pred_cov,
491
+ "y_pred_mean": y_pred_mean,
492
+ "y_pred_cov": y_pred_cov,
493
+ "y_obs_pred_mean": y_obs_pred_mean,
494
+ "y_obs_pred_cov": y_obs_pred_cov,
495
+ }
496
+ outputs = {key: val for key, val in outputs.items() if key in output_fields}
497
+
430
498
  # Return the carry and outputs
431
- return (ll, pred_mean, pred_cov), (
432
- filtered_mean,
433
- filtered_cov,
434
- pred_mean,
435
- pred_cov,
436
- )
499
+ return (ll, pred_mean, pred_cov), outputs
437
500
 
438
501
  # The Kalman filter
439
502
  # Initial carry
440
503
  carry = (0.0, params.initial.mean, params.initial.cov)
441
504
  # Scan over all time steps
442
- (ll, _, _), (filtered_means, filtered_covs, pred_means, pred_covs) = lax.scan(
443
- _step, carry, (t0, t1, t0_idx)
444
- )
505
+ (ll, _, _), outputs = lax.scan(_step, carry, (t0, t1, t0_idx))
506
+ outputs = {"marginal_loglik": ll, **outputs}
445
507
 
446
508
  # Return the posterior object
447
- return PosteriorGSSMFiltered(
448
- marginal_loglik=ll,
449
- filtered_means=filtered_means,
450
- filtered_covariances=filtered_covs,
451
- predicted_means=pred_means,
452
- predicted_covariances=pred_covs,
453
- )
509
+ return PosteriorGSSMFiltered(**outputs)
454
510
 
455
511
 
456
512
  # Kalman Smoothing equations, in continuous-time
@@ -1110,12 +1166,10 @@ def cdlgssm_forecast(
1110
1166
  inputs: optional array of inputs, of shape (1 + num_timesteps) \times input_dim
1111
1167
  - The extra input is needed for the initial emission, i.e., it should be at time t_init
1112
1168
  output_fields: list of fields to return in posterior object.
1113
- These can take the values
1114
- If we forecast Gaussian distributions, based on filtering methods
1115
- "forecasted_state_means",
1116
- "forecasted_state_covariances",
1117
- If we forecast paths, based on solving the SDE
1118
- "forecasted_state_path".
1169
+ Options:
1170
+ `"forecasted_state_means"` (default for Gaussian forecasts)
1171
+ `"forecasted_state_covariances"` (default for Gaussian forecasts)
1172
+ `"forecasted_state_path"` (for point-initialized path forecasts)
1119
1173
  key: random key (e.g., for Ensemble Kalman).
1120
1174
  diffeqsolve_settings: settings for the SDE solver
1121
1175
  warn: whether to issue warnings during filtering (e.g., PSD issues).
@@ -4,7 +4,7 @@ import jax.random as jr
4
4
  from jaxtyping import Array, Float
5
5
 
6
6
  # Type annotations
7
- from typing import Any, Optional, Tuple, Union
7
+ from typing import Any, List, Optional, Tuple, Union
8
8
  from typing_extensions import Protocol
9
9
 
10
10
  # Distributions, compatible with JAX, from TensorFlow Probability
@@ -510,6 +510,12 @@ class ContDiscreteLinearGaussianSSM(SSM):
510
510
  t_emissions: Optional[Float[Array, "ntime 1"]] = None,
511
511
  filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(),
512
512
  inputs: Optional[Float[Array, "ntime input_dim"]] = None,
513
+ output_fields: Optional[List[str]] = [
514
+ "filtered_means",
515
+ "filtered_covariances",
516
+ "predicted_means",
517
+ "predicted_covariances",
518
+ ],
513
519
  warn: bool = True,
514
520
  ) -> PosteriorGSSMFiltered:
515
521
  r"""Run the CD-Kalman filter to compute the filtered posterior distribution over states.
@@ -519,6 +525,20 @@ class ContDiscreteLinearGaussianSSM(SSM):
519
525
  t_emissions: continuous-time specific time instants of observations: if not None, it is an array
520
526
  filter_hyperparams: hyperparameters for the Kalman filter.
521
527
  inputs: optional sequence of inputs.
528
+ output_fields: Which top-level posterior fields to return from the
529
+ filter. Options:
530
+ `"filtered_means"` (default)
531
+ `"filtered_covariances"` (default)
532
+ `"predicted_means"` (default)
533
+ `"predicted_covariances"` (default)
534
+ `"marginal_loglik"`
535
+ `"y_pred_mean"`
536
+ `"y_pred_cov"`
537
+ `"y_obs_pred_mean"`
538
+ `"y_obs_pred_cov"`
539
+ Predictive emission fields are available directly as
540
+ top-level filtered posterior outputs. Unsupported fields raise
541
+ a `ValueError`.
522
542
  warn: whether to warn about numerical issues.
523
543
  Returns:
524
544
  filtered posterior distribution over states.
@@ -526,7 +546,13 @@ class ContDiscreteLinearGaussianSSM(SSM):
526
546
 
527
547
  # Directly run the CD-Kalman filter
528
548
  return cdlgssm_filter(
529
- params, emissions, t_emissions, filter_hyperparams, inputs, warn=warn
549
+ params,
550
+ emissions,
551
+ t_emissions,
552
+ filter_hyperparams,
553
+ inputs,
554
+ output_fields=output_fields,
555
+ warn=warn,
530
556
  )
531
557
 
532
558
  # High-level, user-friendly interface combining filtering and forecasting steps