cd-dynamax 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. cd_dynamax-0.2.5/LICENSE +21 -0
  2. cd_dynamax-0.2.5/PKG-INFO +228 -0
  3. cd_dynamax-0.2.5/README.md +185 -0
  4. cd_dynamax-0.2.5/cd_dynamax/__init__.py +82 -0
  5. cd_dynamax-0.2.5/cd_dynamax/dynamax/__init__.py +9 -0
  6. cd_dynamax-0.2.5/cd_dynamax/dynamax/_version.py +658 -0
  7. cd_dynamax-0.2.5/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +6 -0
  8. cd_dynamax-0.2.5/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +386 -0
  9. cd_dynamax-0.2.5/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +81 -0
  10. cd_dynamax-0.2.5/cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +131 -0
  11. cd_dynamax-0.2.5/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +101 -0
  12. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/__init__.py +27 -0
  13. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/inference.py +629 -0
  14. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/inference_test.py +316 -0
  15. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
  16. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +706 -0
  17. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +231 -0
  18. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +163 -0
  19. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +170 -0
  20. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +177 -0
  21. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +144 -0
  22. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +1031 -0
  23. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +500 -0
  24. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/initial.py +73 -0
  25. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +221 -0
  26. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +175 -0
  27. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +159 -0
  28. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +164 -0
  29. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +165 -0
  30. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +85 -0
  31. cd_dynamax-0.2.5/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +194 -0
  32. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +22 -0
  33. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +105 -0
  34. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +621 -0
  35. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +252 -0
  36. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +418 -0
  37. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +172 -0
  38. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/models.py +615 -0
  39. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +24 -0
  40. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +382 -0
  41. cd_dynamax-0.2.5/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +193 -0
  42. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +9 -0
  43. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +349 -0
  44. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +117 -0
  45. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +178 -0
  46. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +291 -0
  47. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +33 -0
  48. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +116 -0
  49. cd_dynamax-0.2.5/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +203 -0
  50. cd_dynamax-0.2.5/cd_dynamax/dynamax/parameters.py +125 -0
  51. cd_dynamax-0.2.5/cd_dynamax/dynamax/parameters_test.py +135 -0
  52. cd_dynamax-0.2.5/cd_dynamax/dynamax/slds/__init__.py +2 -0
  53. cd_dynamax-0.2.5/cd_dynamax/dynamax/slds/inference.py +339 -0
  54. cd_dynamax-0.2.5/cd_dynamax/dynamax/slds/inference_test.py +124 -0
  55. cd_dynamax-0.2.5/cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +161 -0
  56. cd_dynamax-0.2.5/cd_dynamax/dynamax/slds/models.py +133 -0
  57. cd_dynamax-0.2.5/cd_dynamax/dynamax/ssm.py +471 -0
  58. cd_dynamax-0.2.5/cd_dynamax/dynamax/types.py +10 -0
  59. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/__init__.py +0 -0
  60. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/bijectors.py +34 -0
  61. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/distributions.py +428 -0
  62. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/distributions_test.py +160 -0
  63. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/optimize.py +111 -0
  64. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/plotting.py +151 -0
  65. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/utils.py +276 -0
  66. cd_dynamax-0.2.5/cd_dynamax/dynamax/utils/utils_test.py +43 -0
  67. cd_dynamax-0.2.5/cd_dynamax/dynamax/warnings.py +18 -0
  68. cd_dynamax-0.2.5/cd_dynamax/src/__init__.py +54 -0
  69. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +25 -0
  70. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +280 -0
  71. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +1340 -0
  72. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +712 -0
  73. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +29 -0
  74. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +488 -0
  75. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +1052 -0
  76. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +721 -0
  77. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +684 -0
  78. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +1462 -0
  79. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +25 -0
  80. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +316 -0
  81. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +408 -0
  82. cd_dynamax-0.2.5/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +938 -0
  83. cd_dynamax-0.2.5/cd_dynamax/src/ssm_temissions.py +1617 -0
  84. cd_dynamax-0.2.5/cd_dynamax/src/utils/__init__.py +6 -0
  85. cd_dynamax-0.2.5/cd_dynamax/src/utils/data_driven_models.py +402 -0
  86. cd_dynamax-0.2.5/cd_dynamax/src/utils/data_generator.py +173 -0
  87. cd_dynamax-0.2.5/cd_dynamax/src/utils/debug_utils.py +144 -0
  88. cd_dynamax-0.2.5/cd_dynamax/src/utils/demo_utils.py +846 -0
  89. cd_dynamax-0.2.5/cd_dynamax/src/utils/diffrax_utils.py +224 -0
  90. cd_dynamax-0.2.5/cd_dynamax/src/utils/evaluation_utils.py +18 -0
  91. cd_dynamax-0.2.5/cd_dynamax/src/utils/experiment_utils.py +498 -0
  92. cd_dynamax-0.2.5/cd_dynamax/src/utils/likelihood_eval_utils.py +638 -0
  93. cd_dynamax-0.2.5/cd_dynamax/src/utils/optimize_utils.py +161 -0
  94. cd_dynamax-0.2.5/cd_dynamax/src/utils/physics_based_models.py +322 -0
  95. cd_dynamax-0.2.5/cd_dynamax/src/utils/plotting_chaos_utils.py +358 -0
  96. cd_dynamax-0.2.5/cd_dynamax/src/utils/plotting_utils.py +1604 -0
  97. cd_dynamax-0.2.5/cd_dynamax/src/utils/prior_utils.py +220 -0
  98. cd_dynamax-0.2.5/cd_dynamax/src/utils/simulation_utils.py +492 -0
  99. cd_dynamax-0.2.5/cd_dynamax/src/utils/test_utils.py +225 -0
  100. cd_dynamax-0.2.5/cd_dynamax.egg-info/PKG-INFO +228 -0
  101. cd_dynamax-0.2.5/cd_dynamax.egg-info/SOURCES.txt +110 -0
  102. cd_dynamax-0.2.5/cd_dynamax.egg-info/dependency_links.txt +1 -0
  103. cd_dynamax-0.2.5/cd_dynamax.egg-info/requires.txt +33 -0
  104. cd_dynamax-0.2.5/cd_dynamax.egg-info/top_level.txt +1 -0
  105. cd_dynamax-0.2.5/pyproject.toml +63 -0
  106. cd_dynamax-0.2.5/setup.cfg +4 -0
  107. cd_dynamax-0.2.5/tests/test_cdlgssm_dlgssm_match.py +509 -0
  108. cd_dynamax-0.2.5/tests/test_cdnonlinear_cdlinear_match.py +1027 -0
  109. cd_dynamax-0.2.5/tests/test_filter_forecast_emissions.py +393 -0
  110. cd_dynamax-0.2.5/tests/test_imports.py +11 -0
  111. cd_dynamax-0.2.5/tests/test_models.py +185 -0
  112. cd_dynamax-0.2.5/tests/test_utils_imports.py +6 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Iñigo Urteaga and Matthew Levine
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,228 @@
1
+ Metadata-Version: 2.4
2
+ Name: cd_dynamax
3
+ Version: 0.2.5
4
+ Summary: Continuous-discrete dynamical systems with JAX and related libraries.
5
+ Author: Matthew Levine, Iñigo Urteaga
6
+ Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
7
+ Requires-Python: >=3.11
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: numpy>=2.0
11
+ Requires-Dist: scipy>=1.13
12
+ Requires-Dist: pandas>=2.2
13
+ Requires-Dist: scikit-learn>=1.5
14
+ Requires-Dist: matplotlib>=3.9
15
+ Requires-Dist: seaborn>=0.13
16
+ Requires-Dist: pillow>=10
17
+ Requires-Dist: tqdm>=4.66
18
+ Requires-Dist: jax>=0.6.2
19
+ Requires-Dist: jaxlib>=0.6.2
20
+ Requires-Dist: optax>=0.2.2
21
+ Requires-Dist: equinox>=0.13
22
+ Requires-Dist: diffrax==0.6.2
23
+ Requires-Dist: lineax>=0.0.8
24
+ Requires-Dist: flax>=0.11
25
+ Requires-Dist: blackjax>=1.2
26
+ Requires-Dist: jaxopt>=0.8
27
+ Requires-Dist: chex>=0.1.88
28
+ Requires-Dist: distrax>=0.1.3
29
+ Requires-Dist: orbax-checkpoint>=0.9
30
+ Requires-Dist: etils>=1.5
31
+ Requires-Dist: dm-tree>=0.1.8
32
+ Requires-Dist: graphviz
33
+ Requires-Dist: ipykernel
34
+ Provides-Extra: dev
35
+ Requires-Dist: pytest>=8.0; extra == "dev"
36
+ Requires-Dist: ruff; extra == "dev"
37
+ Requires-Dist: mkdocs-material>=9.7.1; extra == "dev"
38
+ Requires-Dist: mkdocs<2,>=1.6.1; extra == "dev"
39
+ Requires-Dist: mypy>=1.19.1; extra == "dev"
40
+ Requires-Dist: mkdocs-jupyter>=0.25.1; extra == "dev"
41
+ Requires-Dist: mkdocstrings-python; extra == "dev"
42
+ Dynamic: license-file
43
+
44
+ # Overview of cd-dynamax
45
+
46
+ The primary goal of this codebase is to extend [dynamax](https://github.com/probml/dynamax) to a continuous-discrete (CD) state-space-modeling setting, that is, to problems where
47
+
48
+ - the underlying dynamics are continuous in time,
49
+ - and measurements can arise at arbitrary (i.e., non-regular) discrete times.
50
+
51
+ To address these gaps, `cd-dynamax` modifies `dynamax` to accept irregularly sampled data and implements classical algorithms for continuous-discrete filtering and smoothing.
52
+
53
+ ## Mathematical Framework: continuous-discrete state-space models
54
+
55
+ In this repository, we build an expanded toolkit for filtering, forecasting and learning dynamical systems that underpin real-world messy time-series data.
56
+
57
+ We move towards this goal by working with the following flexible mathematical setting:
58
+
59
+ - We assume there exists a (possibly unknown) stochastic dynamical system of form
60
+
61
+ $$dx(t) = f(x(t),t)dt + L(x(t),t) dw(t)$$
62
+
63
+ where $x \in \mathbb{R}^{d_x}$, $x(0) \sim \mathcal{N}(\mu_0, \Sigma_0)$, $f$ is a possibly time-dependent drift function, $L$ is a possibly state- and/or time-dependent diffusion coefficient, and $dw$ is the increment of a $d_x$-dimensional Brownian motion with covariance $Q$.
64
+
65
+ - We assume data are available at arbitrary times $\\{t_k\\}_{k=1}^K$ and observed via a measurement process dictated by
66
+
67
+ $$y(t) = h(x(t)) + \eta(t)$$
68
+
69
+ where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
70
+
71
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
72
+
73
+ Note:
74
+
75
+ - We assume $\eta(t)$ i.i.d. w.r.t. $t$:
76
+ - This assumption places us in the *continuous (dynamics) - discrete (observation)* setting.
77
+ - If $\eta(t)$ had temporal correlations, we would likely adopt a mathematical setting that defines the observation process continuously in time via its own SDE.
78
+
79
+ - Other extensions of the above paradigm include categorical state-spaces and non-additive observation noise distributions
80
+ - These can fit into our code framework (indeed, some are covered in `dynamax`), but have not been our focus.
81
+
82
+ ## cd-dynamax goals and approach
83
+
84
+ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to:
85
+ - Filter: estimate $x(t_K) \\ | \\ Y_K, \\ \theta$
86
+ - Smooth: estimate $\\{x(t)\\}_t \\ | \\ Y_K, \\ \theta$
87
+ - Predict: estimate $x(t > t_K)\\ |\\ Y_K, \\ \theta$
88
+ - Infer parameters: estimate $\theta \\ |\\ Y_K$
89
+
90
+ All of these problems are deeply interconnected.
91
+
92
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
93
+
94
+ - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
95
+ - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
96
+
97
+ - We implement such filtering/smoothing algorithms in an efficient, autodifferentiable framework.
98
+ - We enable usage of modern general-purpose tools for parameter inference (e.g., stochastic gradient descent, Hamiltonian Monte Carlo).
99
+
100
+ - In cd-dynamax, we tackle the parameter inference problem by marginalizing out the unobserved states $\\{x(t)\\}_t$
101
+ - this is a design choice of ours, other alternatives are possible.
102
+ - This marginalization is performed (approximately, in cases of non-linear dynamics) via filtering/smoothing algorithms.
103
+
104
+ ## Codebase description and status
105
+
106
+ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-discrete state space models, where observations are made at specified discrete times rather than at regular intervals.
107
+
108
+ - We leverage [dynamax](https://github.com/probml/dynamax) code
109
+ - Currently, based on a local directory with [Dynamax release 0.1.5](https://github.com/probml/dynamax/releases/tag/0.1.5)
110
+
111
+ - We have implemented the [`cd-dynamax` codebase](./cd_dynamax/README.md) to deal with [continuous-discrete linear and non-linear models](./cd_dynamax/src/README.md), along with several filtering and smoothing algorithms.
112
+
113
+ - The codebase is organized into several key directories:
114
+ ```
115
+ cd_dynamax/
116
+ ├── src/ # Source code for cd-dynamax library
117
+ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
118
+ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
119
+ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
120
+ │ └── utils/ # Utility functions and example models
121
+ ├── dynamax/ # Original dynamax library (as a submodule)
122
+ demos/ # Python demos showcasing cd-dynamax functionality
123
+ ├── python/scripts/ # Python scripts for running demos
124
+ ├── python/notebooks/ # Jupyter notebooks for interactive demos
125
+ ├── python/configs/ # Configuration files for demos
126
+ tests/ # Tests for cd-dynamax functionality
127
+ ```
128
+
129
+ ## [Demos](./demos/python)
130
+
131
+ We provide a set of [demos](./demos/python/README.md) that showcase key functionality of `cd-dynamax`.
132
+
133
+ These [scripts](./demos/python/scripts) and [notebooks](./demos/python/notebooks) illustrate how to learn components of continuous-discrete SDEs from data.
134
+
135
+ For instance:
136
+
137
+ - [Filtering-based likelihood tutorial](./demos/python/notebooks/lorenz63_filter_based_likelihood_tutorial.ipynb) to filtering-based likelihood computation for continuous-discrete SDEs.
138
+
139
+ - [SGD-based model fitting tutorial](./demos/python/notebooks/lorenz63_sgd_fit_to_data_tutorial.ipynb) to SGD-based fitting of continuous-discrete SDE model to data.
140
+
141
+ - [MCMC-based model fitting tutorial](./demos/python/notebooks/lorenz63_mcmc_fit_to_data_tutorial.ipynb) to MCMC-based fitting of continuous-discrete SDE model to data.
142
+
143
+ ## [Tests](./tests)
144
+
145
+ - Several [tests](./tests/README.md) to establish cd-dynamax general functionality, as well as linear and non-linear filters/smoothers tests: e.g., checks that non-linear algorithms applied to linear problems return similar results as linear algorithms.
146
+
147
+ ## [Makefile](./Makefile)
148
+
149
+ - We provide a [Makefile](./Makefile) to automate common tasks, such as running tests and demos.
150
+
151
+ - To run all tests, simply execute:
152
+ ```bash
153
+ make test
154
+ ```
155
+
156
+ - For linting, we use `ruff`:
157
+ ```bash
158
+ make lint
159
+ ```
160
+
161
+ - We can also format files using `ruff`:
162
+ ```bash
163
+ make clean
164
+ ```
165
+
166
+ - The docs can be built using `mkdocs` as:
167
+ ```bash
168
+ make build_docs
169
+ ```
170
+
171
+ # Installation
172
+
173
+ We support installation via **Conda** (recommended) or via a standard Python virtual environment.
174
+
175
+ ---
176
+
177
+ ### Option 1: Conda (recommended)
178
+
179
+ ```bash
180
+ # Create and activate a new environment with Python 3.11
181
+ conda create -n cd_dynamax_joss python=3.11
182
+ conda activate cd_dynamax_joss
183
+
184
+ # Install your package in editable mode (so local changes are picked up)
185
+ pip install -e .[dev]
186
+ ```
187
+
188
+ This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
189
+
190
+ ---
191
+
192
+ ### Option 2: Python venv + pip
193
+
194
+ ```bash
195
+ # Create and activate a virtual environment
196
+ python -m venv .venv
197
+ source .venv/bin/activate # on macOS/Linux
198
+ .venv\Scripts\activate # on Windows
199
+
200
+ # Upgrade pip
201
+ pip install --upgrade pip
202
+
203
+ # Install in editable mode
204
+ pip install -e .[dev]
205
+ ```
206
+
207
+ #### GPU support
208
+ If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
209
+
210
+ Check the [JAX installation docs](https://jax.readthedocs.io/en/latest/installation.html#installation) for the exact commands for your system.
211
+
212
+ ---
213
+
214
+ ### Notes
215
+
216
+ - `pip install -e .` puts the repo in *editable mode*, so changes to source code are immediately available without reinstalling.
217
+
218
+ - If you plan to use plotting features that rely on `graphviz`, make sure the system binary is installed:
219
+ - **macOS:** `brew install graphviz`
220
+ - **Ubuntu/Debian:** `sudo apt install graphviz`
221
+ - **Windows (conda):** `conda install graphviz`
222
+
223
+ - The `[dev]` extra installs additional developer tools (like `pytest`).
224
+ - Once your environment is installed, you can run automated tests:
225
+ ```bash
226
+ pytest
227
+ ```
228
+
@@ -0,0 +1,185 @@
1
+ # Overview of cd-dynamax
2
+
3
+ The primary goal of this codebase is to extend [dynamax](https://github.com/probml/dynamax) to a continuous-discrete (CD) state-space-modeling setting, that is, to problems where
4
+
5
+ - the underlying dynamics are continuous in time,
6
+ - and measurements can arise at arbitrary (i.e., non-regular) discrete times.
7
+
8
+ To address these gaps, `cd-dynamax` modifies `dynamax` to accept irregularly sampled data and implements classical algorithms for continuous-discrete filtering and smoothing.
9
+
10
+ ## Mathematical Framework: continuous-discrete state-space models
11
+
12
+ In this repository, we build an expanded toolkit for filtering, forecasting and learning dynamical systems that underpin real-world messy time-series data.
13
+
14
+ We move towards this goal by working with the following flexible mathematical setting:
15
+
16
+ - We assume there exists a (possibly unknown) stochastic dynamical system of form
17
+
18
+ $$dx(t) = f(x(t),t)dt + L(x(t),t) dw(t)$$
19
+
20
+ where $x \in \mathbb{R}^{d_x}$, $x(0) \sim \mathcal{N}(\mu_0, \Sigma_0)$, $f$ is a possibly time-dependent drift function, $L$ is a possibly state- and/or time-dependent diffusion coefficient, and $dw$ is the increment of a $d_x$-dimensional Brownian motion with covariance $Q$.
21
+
22
+ - We assume data are available at arbitrary times $\\{t_k\\}_{k=1}^K$ and observed via a measurement process dictated by
23
+
24
+ $$y(t) = h(x(t)) + \eta(t)$$
25
+
26
+ where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
27
+
28
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
29
+
30
+ Note:
31
+
32
+ - We assume $\eta(t)$ i.i.d. w.r.t. $t$:
33
+ - This assumption places us in the *continuous (dynamics) - discrete (observation)* setting.
34
+ - If $\eta(t)$ had temporal correlations, we would likely adopt a mathematical setting that defines the observation process continuously in time via its own SDE.
35
+
36
+ - Other extensions of the above paradigm include categorical state-spaces and non-additive observation noise distributions
37
+ - These can fit into our code framework (indeed, some are covered in `dynamax`), but have not been our focus.
38
+
39
+ ## cd-dynamax goals and approach
40
+
41
+ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to:
42
+ - Filter: estimate $x(t_K) \\ | \\ Y_K, \\ \theta$
43
+ - Smooth: estimate $\\{x(t)\\}_t \\ | \\ Y_K, \\ \theta$
44
+ - Predict: estimate $x(t > t_K)\\ |\\ Y_K, \\ \theta$
45
+ - Infer parameters: estimate $\theta \\ |\\ Y_K$
46
+
47
+ All of these problems are deeply interconnected.
48
+
49
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
50
+
51
+ - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
52
+ - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
53
+
54
+ - We implement such filtering/smoothing algorithms in an efficient, autodifferentiable framework.
55
+ - We enable usage of modern general-purpose tools for parameter inference (e.g., stochastic gradient descent, Hamiltonian Monte Carlo).
56
+
57
+ - In cd-dynamax, we tackle the parameter inference problem by marginalizing out the unobserved states $\\{x(t)\\}_t$
58
+ - this is a design choice of ours, other alternatives are possible.
59
+ - This marginalization is performed (approximately, in cases of non-linear dynamics) via filtering/smoothing algorithms.
60
+
61
+ ## Codebase description and status
62
+
63
+ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-discrete state space models, where observations are made at specified discrete times rather than at regular intervals.
64
+
65
+ - We leverage [dynamax](https://github.com/probml/dynamax) code
66
+ - Currently, based on a local directory with [Dynamax release 0.1.5](https://github.com/probml/dynamax/releases/tag/0.1.5)
67
+
68
+ - We have implemented the [`cd-dynamax` codebase](./cd_dynamax/README.md) to deal with [continuous-discrete linear and non-linear models](./cd_dynamax/src/README.md), along with several filtering and smoothing algorithms.
69
+
70
+ - The codebase is organized into several key directories:
71
+ ```
72
+ cd_dynamax/
73
+ ├── src/ # Source code for cd-dynamax library
74
+ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
75
+ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
76
+ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
77
+ │ └── utils/ # Utility functions and example models
78
+ ├── dynamax/ # Original dynamax library (as a submodule)
79
+ demos/ # Python demos showcasing cd-dynamax functionality
80
+ ├── python/scripts/ # Python scripts for running demos
81
+ ├── python/notebooks/ # Jupyter notebooks for interactive demos
82
+ ├── python/configs/ # Configuration files for demos
83
+ tests/ # Tests for cd-dynamax functionality
84
+ ```
85
+
86
+ ## [Demos](./demos/python)
87
+
88
+ We provide a set of [demos](./demos/python/README.md) that showcase key functionality of `cd-dynamax`.
89
+
90
+ These [scripts](./demos/python/scripts) and [notebooks](./demos/python/notebooks) illustrate how to learn components of continuous-discrete SDEs from data.
91
+
92
+ For instance:
93
+
94
+ - [Filtering-based likelihood tutorial](./demos/python/notebooks/lorenz63_filter_based_likelihood_tutorial.ipynb) to filtering-based likelihood computation for continuous-discrete SDEs.
95
+
96
+ - [SGD-based model fitting tutorial](./demos/python/notebooks/lorenz63_sgd_fit_to_data_tutorial.ipynb) to SGD-based fitting of continuous-discrete SDE model to data.
97
+
98
+ - [MCMC-based model fitting tutorial](./demos/python/notebooks/lorenz63_mcmc_fit_to_data_tutorial.ipynb) to MCMC-based fitting of continuous-discrete SDE model to data.
99
+
100
+ ## [Tests](./tests)
101
+
102
+ - Several [tests](./tests/README.md) to establish cd-dynamax general functionality, as well as linear and non-linear filters/smoothers tests: e.g., checks that non-linear algorithms applied to linear problems return similar results as linear algorithms.
103
+
104
+ ## [Makefile](./Makefile)
105
+
106
+ - We provide a [Makefile](./Makefile) to automate common tasks, such as running tests and demos.
107
+
108
+ - To run all tests, simply execute:
109
+ ```bash
110
+ make test
111
+ ```
112
+
113
+ - For linting, we use `ruff`:
114
+ ```bash
115
+ make lint
116
+ ```
117
+
118
+ - We can also format files using `ruff`:
119
+ ```bash
120
+ make clean
121
+ ```
122
+
123
+ - The docs can be built using `mkdocs` as:
124
+ ```bash
125
+ make build_docs
126
+ ```
127
+
128
+ # Installation
129
+
130
+ We support installation via **Conda** (recommended) or via a standard Python virtual environment.
131
+
132
+ ---
133
+
134
+ ### Option 1: Conda (recommended)
135
+
136
+ ```bash
137
+ # Create and activate a new environment with Python 3.11
138
+ conda create -n cd_dynamax_joss python=3.11
139
+ conda activate cd_dynamax_joss
140
+
141
+ # Install your package in editable mode (so local changes are picked up)
142
+ pip install -e .[dev]
143
+ ```
144
+
145
+ This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
146
+
147
+ ---
148
+
149
+ ### Option 2: Python venv + pip
150
+
151
+ ```bash
152
+ # Create and activate a virtual environment
153
+ python -m venv .venv
154
+ source .venv/bin/activate # on macOS/Linux
155
+ .venv\Scripts\activate # on Windows
156
+
157
+ # Upgrade pip
158
+ pip install --upgrade pip
159
+
160
+ # Install in editable mode
161
+ pip install -e .[dev]
162
+ ```
163
+
164
+ #### GPU support
165
+ If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
166
+
167
+ Check the [JAX installation docs](https://jax.readthedocs.io/en/latest/installation.html#installation) for the exact commands for your system.
168
+
169
+ ---
170
+
171
+ ### Notes
172
+
173
+ - `pip install -e .` puts the repo in *editable mode*, so changes to source code are immediately available without reinstalling.
174
+
175
+ - If you plan to use plotting features that rely on `graphviz`, make sure the system binary is installed:
176
+ - **macOS:** `brew install graphviz`
177
+ - **Ubuntu/Debian:** `sudo apt install graphviz`
178
+ - **Windows (conda):** `conda install graphviz`
179
+
180
+ - The `[dev]` extra installs additional developer tools (like `pytest`).
181
+ - Once your environment is installed, you can run automated tests:
182
+ ```bash
183
+ pytest
184
+ ```
185
+
@@ -0,0 +1,82 @@
1
+ # cd_dynamax/__init__.py
2
+
3
+ # Nonlinear SSM
4
+ from .src.continuous_discrete_nonlinear_gaussian_ssm import (
5
+ ContDiscreteNonlinearGaussianSSM,
6
+ ParamsCDNLGSSM,
7
+ cdnlgssm_filter,
8
+ cdnlgssm_smoother,
9
+ cdnlgssm_forecast,
10
+ cdnlgssm_emissions,
11
+ EKFHyperParams,
12
+ UKFHyperParams,
13
+ EnKFHyperParams,
14
+ )
15
+
16
+ from .src.continuous_discrete_nonlinear_ssm import (
17
+ ContDiscreteNonlinearSSM,
18
+ ParamsCDNLSSM,
19
+ DPFHyperParams,
20
+ cdnlssm_filter,
21
+ )
22
+
23
+ # Linear SSM
24
+ from .src.continuous_discrete_linear_gaussian_ssm import (
25
+ ContDiscreteLinearGaussianSSM,
26
+ ParamsCDLGSSM,
27
+ cdlgssm_filter,
28
+ cdlgssm_smoother,
29
+ cdlgssm_forecast,
30
+ cdlgssm_emissions,
31
+ cdlgssm_posterior_sample,
32
+ cdlgssm_joint_sample,
33
+ KFHyperParams,
34
+ )
35
+
36
+ # Discrete-Discrete Linear SSM
37
+ from .dynamax.linear_gaussian_ssm import LinearGaussianSSM
38
+
39
+ # Shared pieces
40
+ from .src.ssm_temissions import SSM, Prior
41
+
42
+ # Utilities (the ones your demos use most)
43
+ from .src.utils.diffrax_utils import adjust_rhs
44
+ from .src.utils.optimize_utils import make_optimizer
45
+ from .src.utils.simulation_utils import make_key_sequence
46
+
47
+ __all__ = [
48
+ # Models
49
+ "ContDiscreteNonlinearGaussianSSM",
50
+ "ContDiscreteNonlinearSSM",
51
+ "ContDiscreteLinearGaussianSSM",
52
+ "LinearGaussianSSM",
53
+ # Params
54
+ "ParamsCDNLGSSM",
55
+ "ParamsCDNLSSM",
56
+ "ParamsCDLGSSM",
57
+ # Nonlinear algos
58
+ "cdnlgssm_filter",
59
+ "cdnlgssm_smoother",
60
+ "cdnlgssm_forecast",
61
+ "cdnlgssm_emissions",
62
+ "cdnlssm_filter",
63
+ "EKFHyperParams",
64
+ "UKFHyperParams",
65
+ "EnKFHyperParams",
66
+ "DPFHyperParams",
67
+ # Linear algos
68
+ "cdlgssm_filter",
69
+ "cdlgssm_smoother",
70
+ "cdlgssm_forecast",
71
+ "cdlgssm_emissions",
72
+ "cdlgssm_posterior_sample",
73
+ "cdlgssm_joint_sample",
74
+ "KFHyperParams",
75
+ # SSM/emissions
76
+ "SSM",
77
+ "Prior",
78
+ # Utils
79
+ "adjust_rhs",
80
+ "make_optimizer",
81
+ "make_key_sequence",
82
+ ]
@@ -0,0 +1,9 @@
1
+ from . import _version
2
+ __version__ = _version.get_versions()['version']
3
+
4
+ # Catch expected warnings from TFP
5
+ from . import warnings
6
+
7
+ # Default to float32 matrix multiplication on TPUs and GPUs
8
+ import jax
9
+ jax.config.update('jax_default_matmul_precision', 'float32')