cd-dynamax 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. cd_dynamax/__init__.py +82 -0
  2. cd_dynamax/dynamax/__init__.py +9 -0
  3. cd_dynamax/dynamax/_version.py +658 -0
  4. cd_dynamax/dynamax/generalized_gaussian_ssm/__init__.py +6 -0
  5. cd_dynamax/dynamax/generalized_gaussian_ssm/inference.py +386 -0
  6. cd_dynamax/dynamax/generalized_gaussian_ssm/inference_test.py +81 -0
  7. cd_dynamax/dynamax/generalized_gaussian_ssm/models.py +131 -0
  8. cd_dynamax/dynamax/generalized_gaussian_ssm/models_test.py +101 -0
  9. cd_dynamax/dynamax/hidden_markov_model/__init__.py +27 -0
  10. cd_dynamax/dynamax/hidden_markov_model/inference.py +629 -0
  11. cd_dynamax/dynamax/hidden_markov_model/inference_test.py +316 -0
  12. cd_dynamax/dynamax/hidden_markov_model/models/__init__.py +0 -0
  13. cd_dynamax/dynamax/hidden_markov_model/models/abstractions.py +706 -0
  14. cd_dynamax/dynamax/hidden_markov_model/models/arhmm.py +231 -0
  15. cd_dynamax/dynamax/hidden_markov_model/models/bernoulli_hmm.py +163 -0
  16. cd_dynamax/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +170 -0
  17. cd_dynamax/dynamax/hidden_markov_model/models/categorical_hmm.py +177 -0
  18. cd_dynamax/dynamax/hidden_markov_model/models/gamma_hmm.py +144 -0
  19. cd_dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py +1031 -0
  20. cd_dynamax/dynamax/hidden_markov_model/models/gmm_hmm.py +500 -0
  21. cd_dynamax/dynamax/hidden_markov_model/models/initial.py +73 -0
  22. cd_dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py +221 -0
  23. cd_dynamax/dynamax/hidden_markov_model/models/logreg_hmm.py +175 -0
  24. cd_dynamax/dynamax/hidden_markov_model/models/multinomial_hmm.py +159 -0
  25. cd_dynamax/dynamax/hidden_markov_model/models/poisson_hmm.py +164 -0
  26. cd_dynamax/dynamax/hidden_markov_model/models/test_models.py +165 -0
  27. cd_dynamax/dynamax/hidden_markov_model/models/transitions.py +85 -0
  28. cd_dynamax/dynamax/hidden_markov_model/parallel_inference.py +194 -0
  29. cd_dynamax/dynamax/linear_gaussian_ssm/__init__.py +22 -0
  30. cd_dynamax/dynamax/linear_gaussian_ssm/builders.py +105 -0
  31. cd_dynamax/dynamax/linear_gaussian_ssm/inference.py +621 -0
  32. cd_dynamax/dynamax/linear_gaussian_ssm/inference_test.py +252 -0
  33. cd_dynamax/dynamax/linear_gaussian_ssm/info_inference.py +418 -0
  34. cd_dynamax/dynamax/linear_gaussian_ssm/info_inference_test.py +172 -0
  35. cd_dynamax/dynamax/linear_gaussian_ssm/models.py +615 -0
  36. cd_dynamax/dynamax/linear_gaussian_ssm/models_test.py +24 -0
  37. cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference.py +382 -0
  38. cd_dynamax/dynamax/linear_gaussian_ssm/parallel_inference_test.py +193 -0
  39. cd_dynamax/dynamax/nonlinear_gaussian_ssm/__init__.py +9 -0
  40. cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +349 -0
  41. cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py +117 -0
  42. cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +178 -0
  43. cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf.py +291 -0
  44. cd_dynamax/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +33 -0
  45. cd_dynamax/dynamax/nonlinear_gaussian_ssm/models.py +116 -0
  46. cd_dynamax/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +203 -0
  47. cd_dynamax/dynamax/parameters.py +125 -0
  48. cd_dynamax/dynamax/parameters_test.py +135 -0
  49. cd_dynamax/dynamax/slds/__init__.py +2 -0
  50. cd_dynamax/dynamax/slds/inference.py +339 -0
  51. cd_dynamax/dynamax/slds/inference_test.py +124 -0
  52. cd_dynamax/dynamax/slds/mixture_kalman_filter_demo.py +161 -0
  53. cd_dynamax/dynamax/slds/models.py +133 -0
  54. cd_dynamax/dynamax/ssm.py +471 -0
  55. cd_dynamax/dynamax/types.py +10 -0
  56. cd_dynamax/dynamax/utils/__init__.py +0 -0
  57. cd_dynamax/dynamax/utils/bijectors.py +34 -0
  58. cd_dynamax/dynamax/utils/distributions.py +428 -0
  59. cd_dynamax/dynamax/utils/distributions_test.py +160 -0
  60. cd_dynamax/dynamax/utils/optimize.py +111 -0
  61. cd_dynamax/dynamax/utils/plotting.py +151 -0
  62. cd_dynamax/dynamax/utils/utils.py +276 -0
  63. cd_dynamax/dynamax/utils/utils_test.py +43 -0
  64. cd_dynamax/dynamax/warnings.py +18 -0
  65. cd_dynamax/src/__init__.py +54 -0
  66. cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/__init__.py +25 -0
  67. cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/cdlgssm_utils.py +280 -0
  68. cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/inference.py +1340 -0
  69. cd_dynamax/src/continuous_discrete_linear_gaussian_ssm/models.py +712 -0
  70. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/__init__.py +29 -0
  71. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/cdnlgssm_utils.py +488 -0
  72. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ekf.py +1052 -0
  73. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_enkf.py +721 -0
  74. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/inference_ukf.py +684 -0
  75. cd_dynamax/src/continuous_discrete_nonlinear_gaussian_ssm/models.py +1462 -0
  76. cd_dynamax/src/continuous_discrete_nonlinear_ssm/__init__.py +25 -0
  77. cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py +316 -0
  78. cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py +408 -0
  79. cd_dynamax/src/continuous_discrete_nonlinear_ssm/models.py +938 -0
  80. cd_dynamax/src/ssm_temissions.py +1617 -0
  81. cd_dynamax/src/utils/__init__.py +6 -0
  82. cd_dynamax/src/utils/data_driven_models.py +402 -0
  83. cd_dynamax/src/utils/data_generator.py +173 -0
  84. cd_dynamax/src/utils/debug_utils.py +144 -0
  85. cd_dynamax/src/utils/demo_utils.py +846 -0
  86. cd_dynamax/src/utils/diffrax_utils.py +224 -0
  87. cd_dynamax/src/utils/evaluation_utils.py +18 -0
  88. cd_dynamax/src/utils/experiment_utils.py +498 -0
  89. cd_dynamax/src/utils/likelihood_eval_utils.py +638 -0
  90. cd_dynamax/src/utils/optimize_utils.py +161 -0
  91. cd_dynamax/src/utils/physics_based_models.py +322 -0
  92. cd_dynamax/src/utils/plotting_chaos_utils.py +358 -0
  93. cd_dynamax/src/utils/plotting_utils.py +1604 -0
  94. cd_dynamax/src/utils/prior_utils.py +220 -0
  95. cd_dynamax/src/utils/simulation_utils.py +492 -0
  96. cd_dynamax/src/utils/test_utils.py +225 -0
  97. cd_dynamax-0.2.5.dist-info/METADATA +228 -0
  98. cd_dynamax-0.2.5.dist-info/RECORD +101 -0
  99. cd_dynamax-0.2.5.dist-info/WHEEL +5 -0
  100. cd_dynamax-0.2.5.dist-info/licenses/LICENSE +21 -0
  101. cd_dynamax-0.2.5.dist-info/top_level.txt +1 -0
cd_dynamax/__init__.py ADDED
@@ -0,0 +1,82 @@
1
+ # cd_dynamax/__init__.py
2
+
3
+ # Nonlinear SSM
4
+ from .src.continuous_discrete_nonlinear_gaussian_ssm import (
5
+ ContDiscreteNonlinearGaussianSSM,
6
+ ParamsCDNLGSSM,
7
+ cdnlgssm_filter,
8
+ cdnlgssm_smoother,
9
+ cdnlgssm_forecast,
10
+ cdnlgssm_emissions,
11
+ EKFHyperParams,
12
+ UKFHyperParams,
13
+ EnKFHyperParams,
14
+ )
15
+
16
+ from .src.continuous_discrete_nonlinear_ssm import (
17
+ ContDiscreteNonlinearSSM,
18
+ ParamsCDNLSSM,
19
+ DPFHyperParams,
20
+ cdnlssm_filter,
21
+ )
22
+
23
+ # Linear SSM
24
+ from .src.continuous_discrete_linear_gaussian_ssm import (
25
+ ContDiscreteLinearGaussianSSM,
26
+ ParamsCDLGSSM,
27
+ cdlgssm_filter,
28
+ cdlgssm_smoother,
29
+ cdlgssm_forecast,
30
+ cdlgssm_emissions,
31
+ cdlgssm_posterior_sample,
32
+ cdlgssm_joint_sample,
33
+ KFHyperParams,
34
+ )
35
+
36
+ # Discrete-Discrete Linear SSM
37
+ from .dynamax.linear_gaussian_ssm import LinearGaussianSSM
38
+
39
+ # Shared pieces
40
+ from .src.ssm_temissions import SSM, Prior
41
+
42
+ # Utilities (the ones your demos use most)
43
+ from .src.utils.diffrax_utils import adjust_rhs
44
+ from .src.utils.optimize_utils import make_optimizer
45
+ from .src.utils.simulation_utils import make_key_sequence
46
+
47
+ __all__ = [
48
+ # Models
49
+ "ContDiscreteNonlinearGaussianSSM",
50
+ "ContDiscreteNonlinearSSM",
51
+ "ContDiscreteLinearGaussianSSM",
52
+ "LinearGaussianSSM",
53
+ # Params
54
+ "ParamsCDNLGSSM",
55
+ "ParamsCDNLSSM",
56
+ "ParamsCDLGSSM",
57
+ # Nonlinear algos
58
+ "cdnlgssm_filter",
59
+ "cdnlgssm_smoother",
60
+ "cdnlgssm_forecast",
61
+ "cdnlgssm_emissions",
62
+ "cdnlssm_filter",
63
+ "EKFHyperParams",
64
+ "UKFHyperParams",
65
+ "EnKFHyperParams",
66
+ "DPFHyperParams",
67
+ # Linear algos
68
+ "cdlgssm_filter",
69
+ "cdlgssm_smoother",
70
+ "cdlgssm_forecast",
71
+ "cdlgssm_emissions",
72
+ "cdlgssm_posterior_sample",
73
+ "cdlgssm_joint_sample",
74
+ "KFHyperParams",
75
+ # SSM/emissions
76
+ "SSM",
77
+ "Prior",
78
+ # Utils
79
+ "adjust_rhs",
80
+ "make_optimizer",
81
+ "make_key_sequence",
82
+ ]
cd_dynamax/dynamax/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from . import _version
2
+ __version__ = _version.get_versions()['version']
3
+
4
+ # Catch expected warnings from TFP
5
+ from . import warnings
6
+
7
+ # Default to float32 matrix multiplication on TPUs and GPUs
8
+ import jax
9
+ jax.config.update('jax_default_matmul_precision', 'float32')