cd-dynamax 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cd_dynamax/__init__.py +82 -0
- cd_dynamax/dynamax/__init__.py +9 -0
- cd_dynamax/dynamax/_version.py +658 -0
- cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +6 -0
- cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +386 -0
- cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +81 -0
- cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +131 -0
- cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +101 -0
- cd_dynamax/dynamax/hidden_markov_model/__init__.py +27 -0
- cd_dynamax/dynamax/hidden_markov_model/inference.py +629 -0
- cd_dynamax/dynamax/hidden_markov_model/inference_test.py +316 -0
- cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
- cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +706 -0
- cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +231 -0
- cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +163 -0
- cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +170 -0
- cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +177 -0
- cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +144 -0
- cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +1031 -0
- cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +500 -0
- cd_dynamax/dynamax/hidden_markov_model/models/initial.py +73 -0
- cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +221 -0
- cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +175 -0
- cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +159 -0
- cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +164 -0
- cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +165 -0
- cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +85 -0
- cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +194 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +22 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +105 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +621 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +252 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +418 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +172 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/models.py +615 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +24 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +382 -0
- cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +193 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +9 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +349 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +117 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +178 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +291 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +33 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +116 -0
- cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +203 -0
- cd_dynamax/dynamax/parameters.py +125 -0
- cd_dynamax/dynamax/parameters_test.py +135 -0
- cd_dynamax/dynamax/slds/__init__.py +2 -0
- cd_dynamax/dynamax/slds/inference.py +339 -0
- cd_dynamax/dynamax/slds/inference_test.py +124 -0
- cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +161 -0
- cd_dynamax/dynamax/slds/models.py +133 -0
- cd_dynamax/dynamax/ssm.py +471 -0
- cd_dynamax/dynamax/types.py +10 -0
- cd_dynamax/dynamax/utils/__init__.py +0 -0
- cd_dynamax/dynamax/utils/bijectors.py +34 -0
- cd_dynamax/dynamax/utils/distributions.py +428 -0
- cd_dynamax/dynamax/utils/distributions_test.py +160 -0
- cd_dynamax/dynamax/utils/optimize.py +111 -0
- cd_dynamax/dynamax/utils/plotting.py +151 -0
- cd_dynamax/dynamax/utils/utils.py +276 -0
- cd_dynamax/dynamax/utils/utils_test.py +43 -0
- cd_dynamax/dynamax/warnings.py +18 -0
- cd_dynamax/src/__init__.py +54 -0
- cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +25 -0
- cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +280 -0
- cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +1340 -0
- cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +712 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +29 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +488 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +1052 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +721 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +684 -0
- cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +1462 -0
- cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +25 -0
- cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +316 -0
- cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +408 -0
- cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +938 -0
- cd_dynamax/src/ssm_temissions.py +1617 -0
- cd_dynamax/src/utils/__init__.py +6 -0
- cd_dynamax/src/utils/data_driven_models.py +402 -0
- cd_dynamax/src/utils/data_generator.py +173 -0
- cd_dynamax/src/utils/debug_utils.py +144 -0
- cd_dynamax/src/utils/demo_utils.py +846 -0
- cd_dynamax/src/utils/diffrax_utils.py +224 -0
- cd_dynamax/src/utils/evaluation_utils.py +18 -0
- cd_dynamax/src/utils/experiment_utils.py +498 -0
- cd_dynamax/src/utils/likelihood_eval_utils.py +638 -0
- cd_dynamax/src/utils/optimize_utils.py +161 -0
- cd_dynamax/src/utils/physics_based_models.py +322 -0
- cd_dynamax/src/utils/plotting_chaos_utils.py +358 -0
- cd_dynamax/src/utils/plotting_utils.py +1604 -0
- cd_dynamax/src/utils/prior_utils.py +220 -0
- cd_dynamax/src/utils/simulation_utils.py +492 -0
- cd_dynamax/src/utils/test_utils.py +225 -0
- cd_dynamax-0.2.5.dist-info/METADATA +228 -0
- cd_dynamax-0.2.5.dist-info/RECORD +101 -0
- cd_dynamax-0.2.5.dist-info/WHEEL +5 -0
- cd_dynamax-0.2.5.dist-info/licenses/LICENSE +21 -0
- cd_dynamax-0.2.5.dist-info/top_level.txt +1 -0
cd_dynamax/__init__.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# cd_dynamax/__init__.py
|
|
2
|
+
|
|
3
|
+
# Nonlinear SSM
|
|
4
|
+
from .src.continuous_discrete_nonlinear_gaussian_ssm import (
|
|
5
|
+
ContDiscreteNonlinearGaussianSSM,
|
|
6
|
+
ParamsCDNLGSSM,
|
|
7
|
+
cdnlgssm_filter,
|
|
8
|
+
cdnlgssm_smoother,
|
|
9
|
+
cdnlgssm_forecast,
|
|
10
|
+
cdnlgssm_emissions,
|
|
11
|
+
EKFHyperParams,
|
|
12
|
+
UKFHyperParams,
|
|
13
|
+
EnKFHyperParams,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from .src.continuous_discrete_nonlinear_ssm import (
|
|
17
|
+
ContDiscreteNonlinearSSM,
|
|
18
|
+
ParamsCDNLSSM,
|
|
19
|
+
DPFHyperParams,
|
|
20
|
+
cdnlssm_filter,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Linear SSM
|
|
24
|
+
from .src.continuous_discrete_linear_gaussian_ssm import (
|
|
25
|
+
ContDiscreteLinearGaussianSSM,
|
|
26
|
+
ParamsCDLGSSM,
|
|
27
|
+
cdlgssm_filter,
|
|
28
|
+
cdlgssm_smoother,
|
|
29
|
+
cdlgssm_forecast,
|
|
30
|
+
cdlgssm_emissions,
|
|
31
|
+
cdlgssm_posterior_sample,
|
|
32
|
+
cdlgssm_joint_sample,
|
|
33
|
+
KFHyperParams,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Discrete-Discrete Linear SSM
|
|
37
|
+
from .dynamax.linear_gaussian_ssm import LinearGaussianSSM
|
|
38
|
+
|
|
39
|
+
# Shared pieces
|
|
40
|
+
from .src.ssm_temissions import SSM, Prior
|
|
41
|
+
|
|
42
|
+
# Utilities (those most commonly used by the demos)
|
|
43
|
+
from .src.utils.diffrax_utils import adjust_rhs
|
|
44
|
+
from .src.utils.optimize_utils import make_optimizer
|
|
45
|
+
from .src.utils.simulation_utils import make_key_sequence
|
|
46
|
+
|
|
47
|
+
# Public API of cd_dynamax, declared category by category.
# The resulting list (names and order) defines what `from cd_dynamax import *` exports.
__all__ = [
    # Model classes: continuous-discrete nonlinear/linear SSMs plus the discrete-discrete LGSSM
    "ContDiscreteNonlinearGaussianSSM",
    "ContDiscreteNonlinearSSM",
    "ContDiscreteLinearGaussianSSM",
    "LinearGaussianSSM",
]
__all__ += [
    # Parameter types, one per continuous-discrete model family
    "ParamsCDNLGSSM",
    "ParamsCDNLSSM",
    "ParamsCDLGSSM",
]
__all__ += [
    # Nonlinear algorithms and their hyperparameter types
    "cdnlgssm_filter",
    "cdnlgssm_smoother",
    "cdnlgssm_forecast",
    "cdnlgssm_emissions",
    "cdnlssm_filter",
    "EKFHyperParams",
    "UKFHyperParams",
    "EnKFHyperParams",
    "DPFHyperParams",
]
__all__ += [
    # Linear algorithms and the Kalman-filter hyperparameter type
    "cdlgssm_filter",
    "cdlgssm_smoother",
    "cdlgssm_forecast",
    "cdlgssm_emissions",
    "cdlgssm_posterior_sample",
    "cdlgssm_joint_sample",
    "KFHyperParams",
]
__all__ += [
    # Shared pieces re-exported from src.ssm_temissions
    "SSM",
    "Prior",
]
__all__ += [
    # Utility helpers re-exported from src.utils
    "adjust_rhs",
    "make_optimizer",
    "make_key_sequence",
]
|
|
cd_dynamax/dynamax/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Expose the package version string reported by the local _version module.
from . import _version

__version__ = _version.get_versions()['version']

# Importing this submodule is how the package catches expected warnings from TFP.
from . import warnings

import jax

# Make float32 the default precision for matrix multiplication on TPUs and GPUs.
# Note: this is a process-wide JAX setting applied as a side effect of importing the package.
jax.config.update('jax_default_matmul_precision', 'float32')
|