pymc-extras 0.4.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/PKG-INFO +1 -1
  2. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/_version.py +16 -3
  3. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/histogram_utils.py +1 -1
  4. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/__init__.py +1 -1
  5. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/printing.py +1 -1
  6. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/__init__.py +4 -4
  7. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/core/__init__.py +1 -1
  8. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/core/statespace.py +94 -23
  9. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/kalman_filter.py +16 -11
  10. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/SARIMAX.py +138 -74
  11. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/VARMAX.py +248 -57
  12. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/__init__.py +2 -2
  13. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/autoregressive.py +49 -24
  14. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/cycle.py +48 -28
  15. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/level_trend.py +61 -29
  16. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/measurement_error.py +22 -5
  17. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/regression.py +47 -18
  18. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/seasonality.py +278 -95
  19. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/core.py +27 -8
  20. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/constants.py +17 -14
  21. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/data_tools.py +1 -1
  22. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/__init__.py +1 -1
  23. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/core/test_statespace.py +45 -14
  24. pymc_extras-0.4.1/tests/statespace/models/structural/components/test_autoregressive.py +267 -0
  25. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/components/test_cycle.py +119 -6
  26. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/components/test_level_trend.py +125 -0
  27. pymc_extras-0.4.1/tests/statespace/models/structural/components/test_measurement_error.py +74 -0
  28. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/components/test_regression.py +102 -1
  29. pymc_extras-0.4.1/tests/statespace/models/structural/components/test_seasonality.py +716 -0
  30. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/test_against_statsmodels.py +13 -13
  31. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/test_core.py +12 -5
  32. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/test_SARIMAX.py +64 -11
  33. pymc_extras-0.4.1/tests/statespace/models/test_VARMAX.py +545 -0
  34. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/utils/test_coord_assignment.py +1 -1
  35. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_histogram_approximation.py +2 -2
  36. pymc_extras-0.4.0/tests/statespace/models/structural/components/test_autoregressive.py +0 -132
  37. pymc_extras-0.4.0/tests/statespace/models/structural/components/test_measurement_error.py +0 -32
  38. pymc_extras-0.4.0/tests/statespace/models/structural/components/test_seasonality.py +0 -439
  39. pymc_extras-0.4.0/tests/statespace/models/test_VARMAX.py +0 -190
  40. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/.gitignore +0 -0
  41. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/.gitpod.yml +0 -0
  42. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/.pre-commit-config.yaml +0 -0
  43. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/.readthedocs.yaml +0 -0
  44. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/CODE_OF_CONDUCT.md +0 -0
  45. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/CONTRIBUTING.md +0 -0
  46. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/LICENSE +0 -0
  47. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/README.md +0 -0
  48. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/codecov.yml +0 -0
  49. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/conda-envs/environment-test.yml +0 -0
  50. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/.nojekyll +0 -0
  51. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/Makefile +0 -0
  52. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/_templates/autosummary/base.rst +0 -0
  53. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/_templates/autosummary/class.rst +0 -0
  54. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/api_reference.rst +0 -0
  55. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/conf.py +0 -0
  56. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/index.rst +0 -0
  57. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/make.bat +0 -0
  58. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/statespace/core.rst +0 -0
  59. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/statespace/filters.rst +0 -0
  60. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/statespace/models/structural.rst +0 -0
  61. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/docs/statespace/models.rst +0 -0
  62. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/__init__.py +0 -0
  63. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/deserialize.py +0 -0
  64. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/__init__.py +5 -5
  65. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/continuous.py +0 -0
  66. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/discrete.py +0 -0
  67. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  68. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  69. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/timeseries.py +0 -0
  70. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/transforms/__init__.py +0 -0
  71. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/distributions/transforms/partial_order.py +0 -0
  72. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/gp/__init__.py +0 -0
  73. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/gp/latent_approx.py +0 -0
  74. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/fit.py +0 -0
  75. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/__init__.py +0 -0
  76. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/find_map.py +0 -0
  77. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/idata.py +0 -0
  78. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/laplace.py +0 -0
  79. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/scipy_interface.py +0 -0
  80. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  81. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
  82. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  83. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/pathfinder.py +0 -0
  84. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/smc/__init__.py +0 -0
  85. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/inference/smc/sampling.py +0 -0
  86. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/linearmodel.py +0 -0
  87. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/__init__.py +0 -0
  88. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/marginal/__init__.py +0 -0
  89. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/marginal/distributions.py +0 -0
  90. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  91. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/marginal/marginal_model.py +0 -0
  92. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/model_api.py +0 -0
  93. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/transforms/__init__.py +0 -0
  94. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model/transforms/autoreparam.py +0 -0
  95. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/model_builder.py +0 -0
  96. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/preprocessing/__init__.py +0 -0
  97. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  98. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/prior.py +0 -0
  99. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/core/compile.py +0 -0
  100. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/core/representation.py +8 -8
  101. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/__init__.py +3 -3
  102. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/distributions.py +0 -0
  103. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
  104. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/utilities.py +0 -0
  105. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/ETS.py +0 -0
  106. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/__init__.py +4 -4
  107. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  108. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/structural/utils.py +0 -0
  109. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/models/utilities.py +0 -0
  110. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/__init__.py +0 -0
  111. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  112. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/utils/__init__.py +0 -0
  113. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/utils/linear_cg.py +0 -0
  114. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/utils/model_equivalence.py +0 -0
  115. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/utils/prior.py +0 -0
  116. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pymc_extras/utils/spline.py +0 -0
  117. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/pyproject.toml +0 -0
  118. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/__init__.py +0 -0
  119. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/conftest.py +0 -0
  120. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/test_continuous.py +0 -0
  121. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/test_discrete.py +0 -0
  122. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/test_discrete_markov_chain.py +0 -0
  123. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/test_multivariate.py +0 -0
  124. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/distributions/test_transform.py +0 -0
  125. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/__init__.py +0 -0
  126. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/laplace_approx/__init__.py +0 -0
  127. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_find_map.py +0 -0
  128. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_idata.py +0 -0
  129. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_laplace.py +0 -0
  130. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
  131. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/__init__.py +0 -0
  132. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/marginal/__init__.py +0 -0
  133. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/marginal/test_distributions.py +0 -0
  134. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/marginal/test_graph_analysis.py +0 -0
  135. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/marginal/test_marginal_model.py +0 -0
  136. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/test_model_api.py +0 -0
  137. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/model/transforms/test_autoreparam.py +0 -0
  138. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/__init__.py +0 -0
  139. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/_data/airpass.csv +0 -0
  140. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/_data/airpassangers.csv +0 -0
  141. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/_data/nile.csv +0 -0
  142. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
  143. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/core/__init__.py +0 -0
  144. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/core/test_representation.py +0 -0
  145. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/core/test_statespace_JAX.py +0 -0
  146. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/filters/__init__.py +0 -0
  147. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/filters/test_distributions.py +0 -0
  148. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/filters/test_kalman_filter.py +0 -0
  149. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/__init__.py +0 -0
  150. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/__init__.py +0 -0
  151. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/components/__init__.py +0 -0
  152. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/structural/conftest.py +0 -0
  153. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/test_ETS.py +0 -0
  154. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/models/test_utilities.py +0 -0
  155. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/shared_fixtures.py +0 -0
  156. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/statsmodel_local_level.py +0 -0
  157. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/test_utilities.py +0 -0
  158. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/statespace/utils/__init__.py +0 -0
  159. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_blackjax_smc.py +0 -0
  160. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_deserialize.py +0 -0
  161. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_linearmodel.py +0 -0
  162. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_model_builder.py +0 -0
  163. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_pathfinder.py +0 -0
  164. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_printing.py +0 -0
  165. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_prior.py +0 -0
  166. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_prior_from_trace.py +0 -0
  167. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/test_splines.py +0 -0
  168. {pymc_extras-0.4.0 → pymc_extras-0.4.1}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pymc-extras
3
- Version: 0.4.0
3
+ Version: 0.4.1
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Project-URL: Documentation, https://pymc-extras.readthedocs.io/
6
6
  Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '0.4.0'
21
- __version_tuple__ = version_tuple = (0, 4, 0)
31
+ __version__ = version = '0.4.1'
32
+ __version_tuple__ = version_tuple = (0, 4, 1)
33
+
34
+ __commit_id__ = commit_id = None
@@ -18,7 +18,7 @@ import pymc as pm
18
18
 
19
19
  from numpy.typing import ArrayLike
20
20
 
21
- __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
21
+ __all__ = ["discrete_histogram", "histogram_approximation", "quantile_histogram"]
22
22
 
23
23
 
24
24
  def quantile_histogram(
@@ -17,4 +17,4 @@ from pymc_extras.inference.laplace_approx.find_map import find_MAP
17
17
  from pymc_extras.inference.laplace_approx.laplace import fit_laplace
18
18
  from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
19
19
 
20
- __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
20
+ __all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
@@ -166,7 +166,7 @@ def model_table(
166
166
 
167
167
  for var in group:
168
168
  var_name = var.name
169
- sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
169
+ sep = f"[b]{' ~' if (var in model.basic_RVs) else ' ='}[/b]"
170
170
  var_expr = variable_expression(model, var, truncate_deterministic)
171
171
  dims_expr = dims_expression(model, var)
172
172
  if dims_expr == "[]":
@@ -1,13 +1,13 @@
1
1
  from pymc_extras.statespace.core.compile import compile_statespace
2
2
  from pymc_extras.statespace.models import structural
3
3
  from pymc_extras.statespace.models.ETS import BayesianETS
4
- from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
4
+ from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
5
5
  from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
6
6
 
7
7
  __all__ = [
8
- "compile_statespace",
9
- "structural",
10
8
  "BayesianETS",
11
- "BayesianSARIMA",
9
+ "BayesianSARIMAX",
12
10
  "BayesianVARMAX",
11
+ "compile_statespace",
12
+ "structural",
13
13
  ]
@@ -4,4 +4,4 @@ from pymc_extras.statespace.core.representation import PytensorRepresentation
4
4
  from pymc_extras.statespace.core.statespace import PyMCStateSpace
5
5
  from pymc_extras.statespace.core.compile import compile_statespace
6
6
 
7
- __all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"]
7
+ __all__ = ["PyMCStateSpace", "PytensorRepresentation", "compile_statespace"]
@@ -60,7 +60,7 @@ FILTER_FACTORY = {
60
60
  def _validate_filter_arg(filter_arg):
61
61
  if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
62
62
  raise ValueError(
63
- f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}'
63
+ f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}"
64
64
  )
65
65
 
66
66
 
@@ -233,10 +233,9 @@ class PyMCStateSpace:
233
233
  self._fit_coords: dict[str, Sequence[str]] | None = None
234
234
  self._fit_dims: dict[str, Sequence[str]] | None = None
235
235
  self._fit_data: pt.TensorVariable | None = None
236
+ self._fit_exog_data: dict[str, dict] = {}
236
237
 
237
238
  self._needs_exog_data = None
238
- self._exog_names = []
239
- self._exog_data_info = {}
240
239
  self._name_to_variable = {}
241
240
  self._name_to_data = {}
242
241
 
@@ -671,7 +670,7 @@ class PyMCStateSpace:
671
670
  pymc_mod = modelcontext(None)
672
671
  for data_name in self.data_names:
673
672
  data = pymc_mod[data_name]
674
- self._exog_data_info[data_name] = {
673
+ self._fit_exog_data[data_name] = {
675
674
  "name": data_name,
676
675
  "value": data.get_value(),
677
676
  "dims": pymc_mod.named_vars_to_dims.get(data_name, None),
@@ -685,7 +684,7 @@ class PyMCStateSpace:
685
684
  --------
686
685
  .. code:: python
687
686
 
688
- ss_mod = pmss.BayesianSARIMA(order=(2, 0, 2), verbose=False, stationary_initialization=True)
687
+ ss_mod = pmss.BayesianSARIMAX(order=(2, 0, 2), verbose=False, stationary_initialization=True)
689
688
  with pm.Model():
690
689
  x0 = pm.Normal('x0', size=ss_mod.k_states)
691
690
  ar_params = pm.Normal('ar_params', size=ss_mod.p)
@@ -805,16 +804,16 @@ class PyMCStateSpace:
805
804
  states, covs = outputs[:4], outputs[4:]
806
805
 
807
806
  state_names = [
808
- "filtered_state",
809
- "predicted_state",
810
- "predicted_observed_state",
811
- "smoothed_state",
807
+ "filtered_states",
808
+ "predicted_states",
809
+ "predicted_observed_states",
810
+ "smoothed_states",
812
811
  ]
813
812
  cov_names = [
814
- "filtered_covariance",
815
- "predicted_covariance",
816
- "predicted_observed_covariance",
817
- "smoothed_covariance",
813
+ "filtered_covariances",
814
+ "predicted_covariances",
815
+ "predicted_observed_covariances",
816
+ "smoothed_covariances",
818
817
  ]
819
818
 
820
819
  with mod:
@@ -939,7 +938,7 @@ class PyMCStateSpace:
939
938
  all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
940
939
  self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
941
940
 
942
- obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
941
+ obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"]
943
942
  obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
944
943
 
945
944
  SequenceMvNormal(
@@ -1082,7 +1081,7 @@ class PyMCStateSpace:
1082
1081
 
1083
1082
  for name in self.data_names:
1084
1083
  if name not in pm_mod:
1085
- pm.Data(**self._exog_data_info[name])
1084
+ pm.Data(**self._fit_exog_data[name])
1086
1085
 
1087
1086
  self._insert_data_variables()
1088
1087
 
@@ -1229,7 +1228,7 @@ class PyMCStateSpace:
1229
1228
  method=mvn_method,
1230
1229
  )
1231
1230
 
1232
- obs_mu = (Z @ mu[..., None]).squeeze(-1)
1231
+ obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
1233
1232
  obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H
1234
1233
 
1235
1234
  SequenceMvNormal(
@@ -1351,7 +1350,7 @@ class PyMCStateSpace:
1351
1350
  self._insert_random_variables()
1352
1351
 
1353
1352
  for name in self.data_names:
1354
- pm.Data(**self._exog_data_info[name])
1353
+ pm.Data(**self._fit_exog_data[name])
1355
1354
 
1356
1355
  self._insert_data_variables()
1357
1356
 
@@ -1651,7 +1650,7 @@ class PyMCStateSpace:
1651
1650
  self._insert_random_variables()
1652
1651
 
1653
1652
  for name in self.data_names:
1654
- pm.Data(**self._exog_data_info[name])
1653
+ pm.Data(**self.data_info[name])
1655
1654
 
1656
1655
  self._insert_data_variables()
1657
1656
  matrices = self.unpack_statespace()
@@ -1678,6 +1677,78 @@ class PyMCStateSpace:
1678
1677
 
1679
1678
  return matrix_idata
1680
1679
 
1680
+ def sample_filter_outputs(
1681
+ self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
1682
+ ):
1683
+ if isinstance(filter_output_names, str):
1684
+ filter_output_names = [filter_output_names]
1685
+
1686
+ if filter_output_names is None:
1687
+ filter_output_names = list(FILTER_OUTPUT_DIMS.keys())
1688
+ else:
1689
+ unknown_filter_output_names = np.setdiff1d(
1690
+ filter_output_names, list(FILTER_OUTPUT_DIMS.keys())
1691
+ )
1692
+ if unknown_filter_output_names.size > 0:
1693
+ raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
1694
+ filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names]
1695
+
1696
+ compile_kwargs = kwargs.pop("compile_kwargs", {})
1697
+ compile_kwargs.setdefault("mode", self.mode)
1698
+
1699
+ with pm.Model(coords=self.coords) as m:
1700
+ self._build_dummy_graph()
1701
+ self._insert_random_variables()
1702
+
1703
+ if self.data_names:
1704
+ for name in self.data_names:
1705
+ pm.Data(**self._fit_exog_data[name])
1706
+
1707
+ self._insert_data_variables()
1708
+
1709
+ x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
1710
+ data = self._fit_data
1711
+
1712
+ obs_coords = m.coords.get(OBS_STATE_DIM, None)
1713
+
1714
+ data, nan_mask = register_data_with_pymc(
1715
+ data,
1716
+ n_obs=self.ssm.k_endog,
1717
+ obs_coords=obs_coords,
1718
+ register_data=True,
1719
+ )
1720
+
1721
+ filter_outputs = self.kalman_filter.build_graph(
1722
+ data,
1723
+ x0,
1724
+ P0,
1725
+ c,
1726
+ d,
1727
+ T,
1728
+ Z,
1729
+ R,
1730
+ H,
1731
+ Q,
1732
+ )
1733
+
1734
+ smoother_outputs = self.kalman_smoother.build_graph(
1735
+ T, R, Q, filter_outputs[0], filter_outputs[3]
1736
+ )
1737
+
1738
+ filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1739
+ for output in filter_outputs:
1740
+ if output.name in filter_output_names:
1741
+ dims = FILTER_OUTPUT_DIMS[output.name]
1742
+ pm.Deterministic(output.name, output, dims=dims)
1743
+
1744
+ with freeze_dims_and_data(m):
1745
+ return pm.sample_posterior_predictive(
1746
+ idata if group == "posterior" else idata.prior,
1747
+ var_names=filter_output_names,
1748
+ compile_kwargs=compile_kwargs,
1749
+ **kwargs,
1750
+ )
1751
+
1681
1752
  @staticmethod
1682
1753
  def _validate_forecast_args(
1683
1754
  time_index: pd.RangeIndex | pd.DatetimeIndex,
@@ -1774,7 +1845,7 @@ class PyMCStateSpace:
1774
1845
  }
1775
1846
 
1776
1847
  if self._needs_exog_data and scenario is None:
1777
- exog_str = ",".join(self._exog_names)
1848
+ exog_str = ",".join(self.data_names)
1778
1849
  suffix = "s" if len(exog_str) > 1 else ""
1779
1850
  raise ValueError(
1780
1851
  f"This model was fit using exogenous data. Forecasting cannot be performed without "
@@ -1783,7 +1854,7 @@ class PyMCStateSpace:
1783
1854
 
1784
1855
  if isinstance(scenario, dict):
1785
1856
  for name, data in scenario.items():
1786
- if name not in self._exog_names:
1857
+ if name not in self.data_names:
1787
1858
  raise ValueError(
1788
1859
  f"Scenario data provided for variable '{name}', which is not an exogenous variable "
1789
1860
  f"used to fit the model."
@@ -1824,12 +1895,12 @@ class PyMCStateSpace:
1824
1895
  # name should only be None on the first non-recursive call. We only arrive to this branch in that case
1825
1896
  # if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
1826
1897
  # needs to be set.
1827
- if len(self._exog_names) > 1:
1898
+ if len(self.data_names) > 1:
1828
1899
  raise ValueError(
1829
1900
  "Multiple exogenous variables were used to fit the model. Provide a dictionary of "
1830
1901
  "scenario data instead."
1831
1902
  )
1832
- name = self._exog_names[0]
1903
+ name = self.data_names[0]
1833
1904
 
1834
1905
  # Omit dataframe from this basic shape check so we can give more detailed information about missing columns
1835
1906
  # in the next check
@@ -2031,7 +2102,7 @@ class PyMCStateSpace:
2031
2102
  return scenario
2032
2103
 
2033
2104
  # This was already checked as valid
2034
- name = self._exog_names[0] if name is None else name
2105
+ name = self.data_names[0] if name is None else name
2035
2106
 
2036
2107
  # Small tidying up in the case we just have a single scenario that's already a dataframe.
2037
2108
  if isinstance(scenario, pd.DataFrame | pd.Series):
@@ -15,10 +15,15 @@ from pymc_extras.statespace.filters.utilities import (
15
15
  split_vars_into_seq_and_nonseq,
16
16
  stabilize,
17
17
  )
18
- from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
18
+ from pymc_extras.statespace.utils.constants import (
19
+ FILTER_OUTPUT_NAMES,
20
+ JITTER_DEFAULT,
21
+ MATRIX_NAMES,
22
+ MISSING_FILL,
23
+ )
19
24
 
20
25
  MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
21
- PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
26
+ PARAM_NAMES = MATRIX_NAMES[2:]
22
27
 
23
28
  assert_time_varying_dim_correct = Assert(
24
29
  "The first dimension of a time varying matrix (the time dimension) must be "
@@ -119,7 +124,7 @@ class BaseFilter(ABC):
119
124
  # There are always two outputs_info wedged between the seqs and non_seqs
120
125
  seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
121
126
  return_ordered = []
122
- for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
127
+ for name in PARAM_NAMES:
123
128
  if name in self.seq_names:
124
129
  idx = self.seq_names.index(name)
125
130
  return_ordered.append(seqs[idx])
@@ -253,28 +258,28 @@ class BaseFilter(ABC):
253
258
  )
254
259
 
255
260
  filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
256
- filtered_states.name = "filtered_states"
261
+ filtered_states.name = FILTER_OUTPUT_NAMES[0]
257
262
 
258
263
  predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
259
- predicted_states.name = "predicted_states"
260
-
261
- observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
262
- observed_states.name = "observed_states"
264
+ predicted_states.name = FILTER_OUTPUT_NAMES[1]
263
265
 
264
266
  filtered_covariances = pt.specify_shape(
265
267
  filtered_covariances, (n, self.n_states, self.n_states)
266
268
  )
267
- filtered_covariances.name = "filtered_covariances"
269
+ filtered_covariances.name = FILTER_OUTPUT_NAMES[2]
268
270
 
269
271
  predicted_covariances = pt.specify_shape(
270
272
  predicted_covariances, (n, self.n_states, self.n_states)
271
273
  )
272
- predicted_covariances.name = "predicted_covariances"
274
+ predicted_covariances.name = FILTER_OUTPUT_NAMES[3]
275
+
276
+ observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
277
+ observed_states.name = FILTER_OUTPUT_NAMES[4]
273
278
 
274
279
  observed_covariances = pt.specify_shape(
275
280
  observed_covariances, (n, self.n_endog, self.n_endog)
276
281
  )
277
- observed_covariances.name = "observed_covariances"
282
+ observed_covariances.name = FILTER_OUTPUT_NAMES[5]
278
283
 
279
284
  loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
280
285
  loglike_obs.name = "loglike_obs"