cd-dynamax 0.3.2__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/PKG-INFO +46 -27
  2. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/README.md +44 -26
  3. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +7 -8
  4. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +14 -20
  5. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +4 -4
  6. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +6 -6
  7. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +15 -15
  8. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +4 -4
  9. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/PKG-INFO +46 -27
  10. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/SOURCES.txt +1 -0
  11. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/requires.txt +3 -0
  12. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/pyproject.toml +3 -2
  13. cd_dynamax-0.3.3/tests/test_state_dependent_diffusion.py +93 -0
  14. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/LICENSE +0 -0
  15. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/__init__.py +0 -0
  16. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/__init__.py +0 -0
  17. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/_version.py +0 -0
  18. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +0 -0
  19. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +0 -0
  20. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +0 -0
  21. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +0 -0
  22. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +0 -0
  23. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/__init__.py +0 -0
  24. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/inference.py +0 -0
  25. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/inference_test.py +0 -0
  26. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
  27. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +0 -0
  28. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +0 -0
  29. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +0 -0
  30. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +0 -0
  31. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +0 -0
  32. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +0 -0
  33. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +0 -0
  34. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +0 -0
  35. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/initial.py +0 -0
  36. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +0 -0
  37. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +0 -0
  38. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +0 -0
  39. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +0 -0
  40. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +0 -0
  41. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +0 -0
  42. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +0 -0
  43. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +0 -0
  44. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +0 -0
  45. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +0 -0
  46. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +0 -0
  47. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +0 -0
  48. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +0 -0
  49. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/models.py +0 -0
  50. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +0 -0
  51. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +0 -0
  52. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +0 -0
  53. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +0 -0
  54. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +0 -0
  55. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +0 -0
  56. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +0 -0
  57. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +0 -0
  58. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +0 -0
  59. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +0 -0
  60. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +0 -0
  61. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/parameters.py +0 -0
  62. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/parameters_test.py +0 -0
  63. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/__init__.py +0 -0
  64. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/inference.py +0 -0
  65. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/inference_test.py +0 -0
  66. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +0 -0
  67. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/slds/models.py +0 -0
  68. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/ssm.py +0 -0
  69. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/types.py +0 -0
  70. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/__init__.py +0 -0
  71. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/bijectors.py +0 -0
  72. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/distributions.py +0 -0
  73. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/distributions_test.py +0 -0
  74. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/optimize.py +0 -0
  75. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/plotting.py +0 -0
  76. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/utils.py +0 -0
  77. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/utils/utils_test.py +0 -0
  78. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/dynamax/warnings.py +0 -0
  79. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/__init__.py +0 -0
  80. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +0 -0
  81. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +0 -0
  82. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +0 -0
  83. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +0 -0
  84. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +0 -0
  85. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +0 -0
  86. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +0 -0
  87. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +0 -0
  88. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/ssm_temissions.py +0 -0
  89. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/__init__.py +0 -0
  90. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/data_driven_models.py +0 -0
  91. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/data_generator.py +0 -0
  92. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/debug_utils.py +0 -0
  93. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/demo_utils.py +0 -0
  94. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/diffrax_utils.py +0 -0
  95. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/evaluation_utils.py +0 -0
  96. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/experiment_utils.py +0 -0
  97. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/likelihood_eval_utils.py +0 -0
  98. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/optimize_utils.py +0 -0
  99. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/physics_based_models.py +0 -0
  100. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/plotting_chaos_utils.py +0 -0
  101. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/plotting_utils.py +0 -0
  102. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/prior_utils.py +0 -0
  103. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/simulation_utils.py +0 -0
  104. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax/src/utils/test_utils.py +0 -0
  105. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/dependency_links.txt +0 -0
  106. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/cd_dynamax.egg-info/top_level.txt +0 -0
  107. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/setup.cfg +0 -0
  108. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_cdlgssm_dlgssm_match.py +0 -0
  109. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_cdnonlinear_cdlinear_match.py +0 -0
  110. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_filter_forecast_emissions.py +0 -0
  111. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_imports.py +0 -0
  112. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_models.py +0 -0
  113. {cd_dynamax-0.3.2 → cd_dynamax-0.3.3}/tests/test_utils_imports.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cd_dynamax
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Continuous-discrete dynamical systems with JAX and related libraries.
5
5
  Author: Matthew Levine, Iñigo Urteaga
6
6
  Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
32
32
  Requires-Dist: fastprogress>=1.0.0
33
33
  Requires-Dist: graphviz
34
34
  Requires-Dist: ipykernel
35
+ Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
35
36
  Provides-Extra: dev
36
37
  Requires-Dist: pytest>=8.0; extra == "dev"
37
38
  Requires-Dist: ruff; extra == "dev"
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
69
70
 
70
71
  where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
71
72
 
72
- We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ L,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
74
 
74
75
  Note:
75
76
 
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
90
91
 
91
92
  All of these problems are deeply interconnected.
92
93
 
93
- - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$.
94
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
94
95
 
95
96
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
96
97
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
113
114
 
114
115
  - The codebase is organized into several key directories:
115
116
  ```
116
- cd_dynamax/
117
- ├── src/ # Source code for cd-dynamax library
118
- │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
119
- │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
120
- │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
121
- └── utils/ # Utility functions and example models
122
- ├── dynamax/ # Original dynamax library (as a submodule)
123
- demos/ # Python demos showcasing cd-dynamax functionality
124
- ├── python/scripts/ # Python scripts for running demos
125
- ├── python/notebooks/ # Jupyter notebooks for interactive demos
126
- ├── python/configs/ # Configuration files for demos
127
- tests/ # Tests for cd-dynamax functionality
117
+ .
118
+ ├── cd_dynamax/ # Source code for cd-dynamax library
119
+ │ ├── src/ # Core source code
120
+ │ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
121
+ │ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
122
+ │ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
123
+ │ │ └── utils/ # Utility functions and example models
124
+ │ └── dynamax/ # Original dynamax library (as a submodule)
125
+ ├── demos/ # Python demos showcasing cd-dynamax functionality
126
+ │ ├── python/scripts/ # Python scripts for running demos
127
+ │ ├── python/notebooks/ # Jupyter notebooks for interactive demos
128
+ │ └── python/configs/ # Configuration files for demos
129
+ └── tests/ # Tests for cd-dynamax functionality
128
130
  ```
129
131
 
130
132
  ## [Demos](./demos/python)
@@ -155,7 +157,7 @@ make test
155
157
  ```
156
158
 
157
159
  - For linting, we use `ruff`:
158
- ```bashbash
160
+ ```bash
159
161
  make lint
160
162
  ```
161
163
 
@@ -171,40 +173,57 @@ make build_docs
171
173
 
172
174
  # Installation
173
175
 
174
- We support installation via **Conda** (recommended) or via a standard Python virtual environment.
176
+ Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
175
177
 
176
178
  ---
177
179
 
178
- ### Option 1: Conda (recommended)
180
+ ### Option 1: Install from PyPI (recommended)
179
181
 
180
182
  ```bash
181
- # Create and activate a new environment with Python 3.11
182
- conda create -n cd_dynamax_joss python=3.11
183
- conda activate cd_dynamax_joss
183
+ # Create and activate a virtual environment
184
+ python -m venv .venv # or `uv venv`
185
+ source .venv/bin/activate # on macOS/Linux
186
+ .venv\Scripts\activate # on Windows
184
187
 
185
- # Install your package in editable mode (so local changes are picked up)
186
- pip install -e .[dev]
188
+ # Upgrade pip
189
+ pip install --upgrade pip
190
+
191
+ # Install latest release from PyPI
192
+ pip install cd-dynamax
187
193
  ```
188
194
 
189
- This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
195
+ `cd-dynamax` is currently **not available on conda-forge**.
190
196
 
191
197
  ---
192
198
 
193
- ### Option 2: Python venv + pip
199
+ ### Option 2: Install from source (editable)
194
200
 
195
201
  ```bash
196
202
  # Create and activate a virtual environment
197
- python -m venv .venv
203
+ python -m venv .venv # or `uv venv`
198
204
  source .venv/bin/activate # on macOS/Linux
199
205
  .venv\Scripts\activate # on Windows
200
206
 
201
207
  # Upgrade pip
202
208
  pip install --upgrade pip
203
209
 
204
- # Install in editable mode
210
+ # Install in editable mode for local development
205
211
  pip install -e .[dev]
206
212
  ```
207
213
 
214
+ ---
215
+
216
+ ### Option 3: Conda environment + pip install
217
+
218
+ ```bash
219
+ # Create and activate a Conda environment with Python 3.11
220
+ conda create -n cd_dynamax python=3.11
221
+ conda activate cd_dynamax
222
+
223
+ # Install latest release from PyPI
224
+ pip install cd-dynamax
225
+ ```
226
+
208
227
  #### GPU support
209
228
  If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
210
229
 
@@ -25,7 +25,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
25
25
 
26
26
  where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
27
27
 
28
- We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ L,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
28
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
29
29
 
30
30
  Note:
31
31
 
@@ -46,7 +46,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
46
46
 
47
47
  All of these problems are deeply interconnected.
48
48
 
49
- - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$.
49
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
50
50
 
51
51
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
52
52
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
@@ -69,18 +69,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
69
69
 
70
70
  - The codebase is organized into several key directories:
71
71
  ```
72
- cd_dynamax/
73
- ├── src/ # Source code for cd-dynamax library
74
- │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
75
- │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
76
- │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
77
- └── utils/ # Utility functions and example models
78
- ├── dynamax/ # Original dynamax library (as a submodule)
79
- demos/ # Python demos showcasing cd-dynamax functionality
80
- ├── python/scripts/ # Python scripts for running demos
81
- ├── python/notebooks/ # Jupyter notebooks for interactive demos
82
- ├── python/configs/ # Configuration files for demos
83
- tests/ # Tests for cd-dynamax functionality
72
+ .
73
+ ├── cd_dynamax/ # Source code for cd-dynamax library
74
+ │ ├── src/ # Core source code
75
+ │ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
76
+ │ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
77
+ │ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
78
+ │ │ └── utils/ # Utility functions and example models
79
+ │ └── dynamax/ # Original dynamax library (as a submodule)
80
+ ├── demos/ # Python demos showcasing cd-dynamax functionality
81
+ │ ├── python/scripts/ # Python scripts for running demos
82
+ │ ├── python/notebooks/ # Jupyter notebooks for interactive demos
83
+ │ └── python/configs/ # Configuration files for demos
84
+ └── tests/ # Tests for cd-dynamax functionality
84
85
  ```
85
86
 
86
87
  ## [Demos](./demos/python)
@@ -111,7 +112,7 @@ make test
111
112
  ```
112
113
 
113
114
  - For linting, we use `ruff`:
114
- ```bashbash
115
+ ```bash
115
116
  make lint
116
117
  ```
117
118
 
@@ -127,40 +128,57 @@ make build_docs
127
128
 
128
129
  # Installation
129
130
 
130
- We support installation via **Conda** (recommended) or via a standard Python virtual environment.
131
+ Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
131
132
 
132
133
  ---
133
134
 
134
- ### Option 1: Conda (recommended)
135
+ ### Option 1: Install from PyPI (recommended)
135
136
 
136
137
  ```bash
137
- # Create and activate a new environment with Python 3.11
138
- conda create -n cd_dynamax_joss python=3.11
139
- conda activate cd_dynamax_joss
138
+ # Create and activate a virtual environment
139
+ python -m venv .venv # or `uv venv`
140
+ source .venv/bin/activate # on macOS/Linux
141
+ .venv\Scripts\activate # on Windows
140
142
 
141
- # Install your package in editable mode (so local changes are picked up)
142
- pip install -e .[dev]
143
+ # Upgrade pip
144
+ pip install --upgrade pip
145
+
146
+ # Install latest release from PyPI
147
+ pip install cd-dynamax
143
148
  ```
144
149
 
145
- This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
150
+ `cd-dynamax` is currently **not available on conda-forge**.
146
151
 
147
152
  ---
148
153
 
149
- ### Option 2: Python venv + pip
154
+ ### Option 2: Install from source (editable)
150
155
 
151
156
  ```bash
152
157
  # Create and activate a virtual environment
153
- python -m venv .venv
158
+ python -m venv .venv # or `uv venv`
154
159
  source .venv/bin/activate # on macOS/Linux
155
160
  .venv\Scripts\activate # on Windows
156
161
 
157
162
  # Upgrade pip
158
163
  pip install --upgrade pip
159
164
 
160
- # Install in editable mode
165
+ # Install in editable mode for local development
161
166
  pip install -e .[dev]
162
167
  ```
163
168
 
169
+ ---
170
+
171
+ ### Option 3: Conda environment + pip install
172
+
173
+ ```bash
174
+ # Create and activate a Conda environment with Python 3.11
175
+ conda create -n cd_dynamax python=3.11
176
+ conda activate cd_dynamax
177
+
178
+ # Install latest release from PyPI
179
+ pip install cd-dynamax
180
+ ```
181
+
164
182
  #### GPU support
165
183
  If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
166
184
 
@@ -119,9 +119,9 @@ def _predict(
119
119
  f = params.dynamics.drift.f
120
120
 
121
121
  # Get time-varying parameters
122
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t)
122
+ Qc_t = params.dynamics.diffusion_cov.f(m, u, t)
123
123
  L_t = (
124
- params.dynamics.diffusion_coefficient.f(None, u, t)
124
+ params.dynamics.diffusion_coefficient.f(m, u, t)
125
125
  * filter_hyperparams.cov_rescaling
126
126
  )
127
127
 
@@ -185,9 +185,9 @@ def _predict(
185
185
  dt = filter_hyperparams.dt_average
186
186
 
187
187
  # Covariance parameters at time t0
188
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
188
+ Qc_t = params.dynamics.diffusion_cov.f(m, u, t0)
189
189
  L_t = (
190
- params.dynamics.diffusion_coefficient.f(None, u, t0)
190
+ params.dynamics.diffusion_coefficient.f(m, u, t0)
191
191
  * filter_hyperparams.cov_rescaling
192
192
  )
193
193
  # Covariance update
@@ -527,12 +527,11 @@ def _smooth(
527
527
  m_smooth, P_smooth = y
528
528
  m_filter, P_filter = args
529
529
 
530
- # TODO: possibly time- and parameter-dependent functions
531
530
  f = params.dynamics.drift.f
532
531
  # Get time-varying parameters
533
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t)
532
+ Qc_t = params.dynamics.diffusion_cov.f(m_filter, u, t)
534
533
  L_t = (
535
- params.dynamics.diffusion_coefficient.f(None, u, t)
534
+ params.dynamics.diffusion_coefficient.f(m_filter, u, t)
536
535
  * filter_hyperparams.cov_rescaling
537
536
  )
538
537
 
@@ -831,7 +830,7 @@ def extended_kalman_posterior_sample(
831
830
 
832
831
  # Get parameters and inputs for time t0
833
832
  u = inputs[t0_idx]
834
- Q = params.dynamics.diffusion_cov.f(None, u, t0)
833
+ Q = params.dynamics.diffusion_cov.f(filtered_mean, u, t0)
835
834
 
836
835
  # Condition on next state
837
836
  smoothed_mean, smoothed_cov = _condition_on(
@@ -113,9 +113,9 @@ def _predict(
113
113
  # First order EnKF diffusion
114
114
  def diffusion(t, y, args):
115
115
  # Get parameters at time t and input u
116
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t)
116
+ Qc_t = params.dynamics.diffusion_cov.f(y, u, t)
117
117
  L_t = (
118
- params.dynamics.diffusion_coefficient.f(None, u, t)
118
+ params.dynamics.diffusion_coefficient.f(y, u, t)
119
119
  * filter_hyperparams.cov_rescaling
120
120
  )
121
121
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
@@ -155,26 +155,20 @@ def _predict(
155
155
  # but the same amount of noise is added after each measurement.
156
156
  dt = filter_hyperparams.dt_average
157
157
 
158
- # Get diffusion parameters at time t0 and input u
159
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
160
- L_t = (
161
- params.dynamics.diffusion_coefficient.f(None, u, t0)
162
- * filter_hyperparams.cov_rescaling
163
- )
158
+ key_array = jr.split(key, x.shape[0])
164
159
 
165
- # Compute state noise covariance
166
- state_noise_cov = dt * L_t @ Qc_t @ L_t.T # D_hid x D_hid
160
+ def _sample_noise(x_i, key_i):
161
+ Qc_t = params.dynamics.diffusion_cov.f(x_i, u, t0)
162
+ L_t = (
163
+ params.dynamics.diffusion_coefficient.f(x_i, u, t0)
164
+ * filter_hyperparams.cov_rescaling
165
+ )
166
+ state_noise_cov = dt * L_t @ Qc_t @ L_t.T # D_hid x D_hid
167
+ return jr.multivariate_normal(
168
+ key=key_i, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
169
+ )
167
170
 
168
- # Noise realizations from MVN(0, state_noise_cov), for each particle particles (N x D_hid)
169
- # Split keys for each particle
170
- key_array = jr.split(key, x.shape[0])
171
- # vmap over particles
172
- noise = vmap(
173
- lambda key: jr.multivariate_normal(
174
- key=key, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
175
- ),
176
- in_axes=0,
177
- )(key_array)
171
+ noise = vmap(_sample_noise, in_axes=(0, 0))(x_pred, key_array)
178
172
 
179
173
  # Add noise to predicted particles
180
174
  x_pred += noise
@@ -214,9 +214,9 @@ def _predict(
214
214
  dt = filter_hyperparams.dt_average
215
215
 
216
216
  # Get diffusion parameters at time t0 and input u
217
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
217
+ Qc_t = params.dynamics.diffusion_cov.f(m, u, t0)
218
218
  L_t = (
219
- params.dynamics.diffusion_coefficient.f(None, u, t0)
219
+ params.dynamics.diffusion_coefficient.f(m, u, t0)
220
220
  * filter_hyperparams.cov_rescaling
221
221
  )
222
222
  # Compute state noise covariance
@@ -237,9 +237,9 @@ def _predict(
237
237
  f = (
238
238
  params.dynamics.drift.f
239
239
  ) # TODO: reconsider when we want time-varying dynamics functions
240
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t)
240
+ Qc_t = params.dynamics.diffusion_cov.f(m_t, u, t)
241
241
  L_t = (
242
- params.dynamics.diffusion_coefficient.f(None, u, t)
242
+ params.dynamics.diffusion_coefficient.f(m_t, u, t)
243
243
  * filter_hyperparams.cov_rescaling
244
244
  )
245
245
 
@@ -111,8 +111,8 @@ def compute_pushforward(
111
111
  f = params.dynamics.drift.f
112
112
 
113
113
  # Get time-varying parameters
114
- Qc_t = params.dynamics.diffusion_cov.f(None, inputs, t)
115
- L_t = params.dynamics.diffusion_coefficient.f(None, inputs, t)
114
+ Qc_t = params.dynamics.diffusion_cov.f(x, inputs, t)
115
+ L_t = params.dynamics.diffusion_coefficient.f(x, inputs, t)
116
116
 
117
117
  # Different SDE approximations to the dynamics
118
118
  # Zeroth-order (no gradient information),
@@ -964,8 +964,8 @@ def cdnlgssm_path_sample(
964
964
  return params.dynamics.drift.f(y, inpt, t)
965
965
 
966
966
  def diffusion(t, y, args):
967
- Qc_t = params.dynamics.diffusion_cov.f(None, inpt, t)
968
- L_t = params.dynamics.diffusion_coefficient.f(None, inpt, t)
967
+ Qc_t = params.dynamics.diffusion_cov.f(y, inpt, t)
968
+ L_t = params.dynamics.diffusion_coefficient.f(y, inpt, t)
969
969
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
970
970
  combined_diffusion = L_t @ Q_sqrt
971
971
  return combined_diffusion
@@ -1261,9 +1261,9 @@ def cdnlgssm_forecast(
1261
1261
  return params.dynamics.drift.f(y, inputs[t0_idx], t)
1262
1262
 
1263
1263
  def diffusion(t, y, args):
1264
- Qc_t = params.dynamics.diffusion_cov.f(None, inputs[t0_idx], t)
1264
+ Qc_t = params.dynamics.diffusion_cov.f(y, inputs[t0_idx], t)
1265
1265
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
1266
- L_t = params.dynamics.diffusion_coefficient.f(None, inputs[t0_idx], t)
1266
+ L_t = params.dynamics.diffusion_coefficient.f(y, inputs[t0_idx], t)
1267
1267
  combined_diffusion = L_t @ Q_sqrt
1268
1268
  return combined_diffusion
1269
1269
 
@@ -112,9 +112,9 @@ def _predict(
112
112
  else:
113
113
 
114
114
  def diffusion(t, y, args):
115
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t)
115
+ Qc_t = params.dynamics.diffusion_cov.f(y, u, t)
116
116
  L_t = (
117
- params.dynamics.diffusion_coefficient.f(None, u, t)
117
+ params.dynamics.diffusion_coefficient.f(y, u, t)
118
118
  * filter_hyperparams.cov_rescaling
119
119
  )
120
120
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
@@ -143,20 +143,20 @@ def _predict(
143
143
  else:
144
144
  dt = filter_hyperparams.dt_average
145
145
 
146
- Qc_t = params.dynamics.diffusion_cov.f(None, u, t0)
147
- L_t = (
148
- params.dynamics.diffusion_coefficient.f(None, u, t0)
149
- * filter_hyperparams.cov_rescaling
150
- )
151
- state_noise_cov = dt * L_t @ Qc_t @ L_t.T # D_hid x D_hid
152
-
153
146
  key_array = jr.split(key, x.shape[0])
154
- noise = vmap(
155
- lambda key: jr.multivariate_normal(
156
- key=key, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
157
- ),
158
- in_axes=0,
159
- )(key_array)
147
+
148
+ def _sample_noise(x_i, key_i):
149
+ Qc_t = params.dynamics.diffusion_cov.f(x_i, u, t0)
150
+ L_t = (
151
+ params.dynamics.diffusion_coefficient.f(x_i, u, t0)
152
+ * filter_hyperparams.cov_rescaling
153
+ )
154
+ state_noise_cov = dt * L_t @ Qc_t @ L_t.T
155
+ return jr.multivariate_normal(
156
+ key=key_i, mean=jnp.zeros(x.shape[1]), cov=state_noise_cov
157
+ )
158
+
159
+ noise = vmap(_sample_noise, in_axes=(0, 0))(x_pred, key_array)
160
160
  x_pred += noise
161
161
  return x_pred
162
162
 
@@ -387,8 +387,8 @@ class ContDiscreteNonlinearSSM(SSM):
387
387
  return params.dynamics.drift.f(y, u_prev_t, t)
388
388
 
389
389
  def diffusion(t, y, _):
390
- Qc_t = params.dynamics.diffusion_cov.f(None, u_prev_t, t)
391
- L_t = params.dynamics.diffusion_coefficient.f(None, u_prev_t, t)
390
+ Qc_t = params.dynamics.diffusion_cov.f(y, u_prev_t, t)
391
+ L_t = params.dynamics.diffusion_coefficient.f(y, u_prev_t, t)
392
392
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
393
393
  return L_t @ Q_sqrt
394
394
 
@@ -818,10 +818,10 @@ def cdnlssm_forecast(
818
818
  return params.dynamics.drift.f(y, forecast_inputs[t0_idx], t)
819
819
 
820
820
  def diffusion(t, y, args):
821
- Qc_t = params.dynamics.diffusion_cov.f(None, forecast_inputs[t0_idx], t)
821
+ Qc_t = params.dynamics.diffusion_cov.f(y, forecast_inputs[t0_idx], t)
822
822
  Q_sqrt = jnp.linalg.cholesky(Qc_t)
823
823
  L_t = params.dynamics.diffusion_coefficient.f(
824
- None, forecast_inputs[t0_idx], t
824
+ y, forecast_inputs[t0_idx], t
825
825
  )
826
826
  combined_diffusion = L_t @ Q_sqrt
827
827
  return combined_diffusion
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cd_dynamax
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Continuous-discrete dynamical systems with JAX and related libraries.
5
5
  Author: Matthew Levine, Iñigo Urteaga
6
6
  Maintainer-email: Matthew Levine <matt@basis.ai>, Iñigo Urteaga <iurteaga@bcamath.org>
@@ -32,6 +32,7 @@ Requires-Dist: dm-tree>=0.1.8
32
32
  Requires-Dist: fastprogress>=1.0.0
33
33
  Requires-Dist: graphviz
34
34
  Requires-Dist: ipykernel
35
+ Requires-Dist: orbax-checkpoint<0.11.3; sys_platform == "win32"
35
36
  Provides-Extra: dev
36
37
  Requires-Dist: pytest>=8.0; extra == "dev"
37
38
  Requires-Dist: ruff; extra == "dev"
@@ -69,7 +70,7 @@ $$y(t) = h(x(t)) + \eta(t)$$
69
70
 
70
71
  where $h: \mathbb{R}^{d_x} \mapsto \mathbb{R}^{d_y}$ creates a $d_y$-dimensional observation from the $d_x$-dimensional state of the dynamical system $x(t)$ (a realization of the above SDE), and $\eta(t)$ applies additive Gaussian noise to the observation.
71
72
 
72
- We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ L,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
+ We denote the collection of all parameters as $\theta = \\{f,\\ L,\\ \mu_0,\\ \Sigma_0,\\ Q,\\ h,\\ \textrm{Law}(\eta) \\}$.
73
74
 
74
75
  Note:
75
76
 
@@ -90,7 +91,7 @@ For a given set of observations $Y_K = [y(t_1),\\ \dots ,\\ y(t_K)]$, we wish to
90
91
 
91
92
  All of these problems are deeply interconnected.
92
93
 
93
- - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)}, \\ \dots \\, \\ Y^{(N)}]$.
94
+ - In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations ($[Y^{(1)},\\ \dots ,\\ Y^{(N)}]$).
94
95
 
95
96
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
96
97
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar set of system-defining parameters $\theta^{(n)}$.
@@ -113,18 +114,19 @@ The `cd-dynamax` codebase extends the `dynamax` library to support continuous-di
113
114
 
114
115
  - The codebase is organized into several key directories:
115
116
  ```
116
- cd_dynamax/
117
- ├── src/ # Source code for cd-dynamax library
118
- │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
119
- │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
120
- │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
121
- └── utils/ # Utility functions and example models
122
- ├── dynamax/ # Original dynamax library (as a submodule)
123
- demos/ # Python demos showcasing cd-dynamax functionality
124
- ├── python/scripts/ # Python scripts for running demos
125
- ├── python/notebooks/ # Jupyter notebooks for interactive demos
126
- ├── python/configs/ # Configuration files for demos
127
- tests/ # Tests for cd-dynamax functionality
117
+ .
118
+ ├── cd_dynamax/ # Source code for cd-dynamax library
119
+ │ ├── src/ # Core source code
120
+ │ │ ├── continuous_discrete_linear_gaussian_ssm/ # CD-LGSSM models and algorithms
121
+ │ │ ├── continuous_discrete_nonlinear_gaussian_ssm/ # CD-NLGSSM models and algorithms
122
+ │ │ ├── ssm_temissions.py # Modified SSM class for discrete emissions
123
+ │ │ └── utils/ # Utility functions and example models
124
+ │ └── dynamax/ # Original dynamax library (as a submodule)
125
+ ├── demos/ # Python demos showcasing cd-dynamax functionality
126
+ │ ├── python/scripts/ # Python scripts for running demos
127
+ │ ├── python/notebooks/ # Jupyter notebooks for interactive demos
128
+ │ └── python/configs/ # Configuration files for demos
129
+ └── tests/ # Tests for cd-dynamax functionality
128
130
  ```
129
131
 
130
132
  ## [Demos](./demos/python)
@@ -155,7 +157,7 @@ make test
155
157
  ```
156
158
 
157
159
  - For linting, we use `ruff`:
158
- ```bashbash
160
+ ```bash
159
161
  make lint
160
162
  ```
161
163
 
@@ -171,40 +173,57 @@ make build_docs
171
173
 
172
174
  # Installation
173
175
 
174
- We support installation via **Conda** (recommended) or via a standard Python virtual environment.
176
+ Install from **PyPI** (recommended), from source in editable mode, or with a Conda-managed environment.
175
177
 
176
178
  ---
177
179
 
178
- ### Option 1: Conda (recommended)
180
+ ### Option 1: Install from PyPI (recommended)
179
181
 
180
182
  ```bash
181
- # Create and activate a new environment with Python 3.11
182
- conda create -n cd_dynamax_joss python=3.11
183
- conda activate cd_dynamax_joss
183
+ # Create and activate a virtual environment
184
+ python -m venv .venv # or `uv venv`
185
+ source .venv/bin/activate # on macOS/Linux
186
+ .venv\Scripts\activate # on Windows
184
187
 
185
- # Install your package in editable mode (so local changes are picked up)
186
- pip install -e .[dev]
188
+ # Upgrade pip
189
+ pip install --upgrade pip
190
+
191
+ # Install latest release from PyPI
192
+ pip install cd-dynamax
187
193
  ```
188
194
 
189
- This installs the core dependencies listed in `pyproject.toml`, along with optional developer tools (`pytest`, etc.) if you use `[dev]`.
195
+ `cd-dynamax` is currently **not available on Conda Forge**.
190
196
 
191
197
  ---
192
198
 
193
- ### Option 2: Python venv + pip
199
+ ### Option 2: Install from source (editable)
194
200
 
195
201
  ```bash
196
202
  # Create and activate a virtual environment
197
- python -m venv .venv
203
+ python -m venv .venv # or `uv venv`
198
204
  source .venv/bin/activate # on macOS/Linux
199
205
  .venv\Scripts\activate # on Windows
200
206
 
201
207
  # Upgrade pip
202
208
  pip install --upgrade pip
203
209
 
204
- # Install in editable mode
210
+ # Install in editable mode for local development
205
211
  pip install -e .[dev]
206
212
  ```
207
213
 
214
+ ---
215
+
216
+ ### Option 3: Conda environment + pip install
217
+
218
+ ```bash
219
+ # Create and activate a Conda environment with Python 3.11
220
+ conda create -n cd_dynamax python=3.11
221
+ conda activate cd_dynamax
222
+
223
+ # Install latest release from PyPI
224
+ pip install cd-dynamax
225
+ ```
226
+
208
227
  #### GPU support
209
228
  If you want GPU acceleration with JAX, you must install a CUDA-enabled `jaxlib` wheel.
210
229
 
@@ -107,4 +107,5 @@ tests/test_cdnonlinear_cdlinear_match.py
107
107
  tests/test_filter_forecast_emissions.py
108
108
  tests/test_imports.py
109
109
  tests/test_models.py
110
+ tests/test_state_dependent_diffusion.py
110
111
  tests/test_utils_imports.py
@@ -24,6 +24,9 @@ fastprogress>=1.0.0
24
24
  graphviz
25
25
  ipykernel
26
26
 
27
+ [:sys_platform == "win32"]
28
+ orbax-checkpoint<0.11.3
29
+
27
30
  [dev]
28
31
  pytest>=8.0
29
32
  ruff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "cd_dynamax"
7
- version = "0.3.2"
7
+ version = "0.3.3"
8
8
  requires-python = ">=3.11"
9
9
  description = "Continuous-discrete dynamical systems with JAX and related libraries."
10
10
  readme = "README.md"
@@ -42,7 +42,8 @@ dependencies = [
42
42
  "dm-tree>=0.1.8",
43
43
  "fastprogress>=1.0.0",
44
44
  "graphviz",
45
- "ipykernel"
45
+ "ipykernel",
46
+ "orbax-checkpoint<0.11.3; sys_platform == 'win32'"
46
47
  ]
47
48
 
48
49
  [tool.setuptools.packages.find]
@@ -0,0 +1,93 @@
1
+ import jax.numpy as jnp
2
+ import jax.random as jr
3
+ from typing import NamedTuple
4
+
5
+ from cd_dynamax.src.continuous_discrete_nonlinear_gaussian_ssm import (
6
+ ContDiscreteNonlinearGaussianSSM,
7
+ )
8
+ from cd_dynamax.src.continuous_discrete_nonlinear_gaussian_ssm.inference_ekf import (
9
+ EKFHyperParams,
10
+ extended_kalman_filter,
11
+ )
12
+ from cd_dynamax.src.continuous_discrete_nonlinear_ssm import ContDiscreteNonlinearSSM
13
+ from cd_dynamax.src.continuous_discrete_nonlinear_ssm.inference_dpf import (
14
+ DPFHyperParams,
15
+ diff_particle_filter,
16
+ )
17
+
18
+
19
+ class StateDependentDiagonalDiffusion(NamedTuple):
20
+ scale: float = 0.1
21
+
22
+ def f(self, x, u=None, t=None):
23
+ if x is None:
24
+ raise ValueError("state-dependent diffusion requires a state argument")
25
+ x = jnp.atleast_1d(x)
26
+ return jnp.diag(1.0 + self.scale * jnp.square(x))
27
+
28
+
29
+ def test_cdnlgssm_sample_path_supports_state_dependent_diffusion():
30
+ model = ContDiscreteNonlinearGaussianSSM(state_dim=1, emission_dim=1)
31
+ params, _ = model.initialize()
32
+ params = params._replace(
33
+ dynamics=params.dynamics._replace(
34
+ diffusion_coefficient=StateDependentDiagonalDiffusion()
35
+ )
36
+ )
37
+
38
+ states, emissions = model.sample_path(params, key=jr.PRNGKey(0), num_timesteps=3)
39
+
40
+ assert states.shape == (3, 1)
41
+ assert emissions.shape == (3, 1)
42
+ assert jnp.all(jnp.isfinite(states))
43
+ assert jnp.all(jnp.isfinite(emissions))
44
+
45
+
46
+ def test_extended_kalman_filter_supports_state_dependent_diffusion():
47
+ model = ContDiscreteNonlinearGaussianSSM(state_dim=1, emission_dim=1)
48
+ params, _ = model.initialize()
49
+ params = params._replace(
50
+ dynamics=params.dynamics._replace(
51
+ diffusion_coefficient=StateDependentDiagonalDiffusion()
52
+ )
53
+ )
54
+
55
+ emissions = jnp.zeros((2, 1))
56
+ t_emissions = jnp.array([[0.0], [0.1]])
57
+
58
+ posterior = extended_kalman_filter(
59
+ params,
60
+ emissions,
61
+ t_emissions=t_emissions,
62
+ filter_hyperparams=EKFHyperParams(state_order="first"),
63
+ )
64
+
65
+ assert posterior.filtered_means.shape == (2, 1)
66
+ assert posterior.predicted_covariances.shape == (2, 1, 1)
67
+ assert jnp.all(jnp.isfinite(posterior.filtered_means))
68
+ assert jnp.all(jnp.isfinite(posterior.predicted_covariances))
69
+
70
+
71
+ def test_diff_particle_filter_supports_state_dependent_diffusion():
72
+ model = ContDiscreteNonlinearSSM(state_dim=1, emission_dim=1)
73
+ params, _ = model.initialize()
74
+ params = params._replace(
75
+ dynamics=params.dynamics._replace(
76
+ diffusion_coefficient=StateDependentDiagonalDiffusion()
77
+ )
78
+ )
79
+
80
+ emissions = jnp.zeros((2, 1))
81
+ ts = jnp.array([0.0, 0.1])
82
+
83
+ particles, log_weights, log_evidence = diff_particle_filter(
84
+ jr.PRNGKey(0),
85
+ params,
86
+ emissions,
87
+ ts=ts,
88
+ hyperparams=DPFHyperParams(N_particles=8),
89
+ )
90
+
91
+ assert particles.shape == (2, 8, 1)
92
+ assert log_weights.shape == (2, 8)
93
+ assert jnp.isfinite(log_evidence)
File without changes
File without changes