pymc-extras 0.3.1__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/PKG-INFO +5 -4
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/_version.py +16 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/conda-envs/environment-test.yml +4 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/histogram_utils.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/__init__.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/find_map.py +12 -5
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/idata.py +4 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/laplace.py +6 -4
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/pathfinder.py +1 -2
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/printing.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/__init__.py +4 -4
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/core/__init__.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/core/statespace.py +94 -23
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/kalman_filter.py +16 -11
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/models/SARIMAX.py +138 -74
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/models/VARMAX.py +248 -57
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/models/__init__.py +2 -2
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/__init__.py +21 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/autoregressive.py +213 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/cycle.py +325 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/level_trend.py +289 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/measurement_error.py +154 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/regression.py +257 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components/seasonality.py +628 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/core.py +919 -0
- pymc_extras-0.4.1/pymc_extras/statespace/models/structural/utils.py +16 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/models/utilities.py +285 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/constants.py +21 -18
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/data_tools.py +4 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pyproject.toml +4 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/__init__.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_find_map.py +34 -7
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_laplace.py +76 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/core/test_statespace.py +285 -27
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/core/test_statespace_JAX.py +1 -1
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/filters/test_distributions.py +6 -6
- pymc_extras-0.4.1/tests/statespace/models/structural/components/__init__.py +0 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_autoregressive.py +267 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_cycle.py +401 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_level_trend.py +283 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_measurement_error.py +74 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_regression.py +338 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/components/test_seasonality.py +716 -0
- pymc_extras-0.4.1/tests/statespace/models/structural/conftest.py +29 -0
- pymc_extras-0.3.1/tests/statespace/models/test_structural.py → pymc_extras-0.4.1/tests/statespace/models/structural/test_against_statsmodels.py +34 -339
- pymc_extras-0.4.1/tests/statespace/models/structural/test_core.py +197 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/models/test_SARIMAX.py +70 -17
- pymc_extras-0.4.1/tests/statespace/models/test_VARMAX.py +545 -0
- pymc_extras-0.4.1/tests/statespace/models/test_utilities.py +298 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/test_utilities.py +3 -2
- pymc_extras-0.4.1/tests/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/utils/test_coord_assignment.py +7 -7
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_histogram_approximation.py +2 -2
- pymc_extras-0.4.1/tests/utils.py +0 -0
- pymc_extras-0.3.1/pymc_extras/statespace/models/structural.py +0 -1679
- pymc_extras-0.3.1/tests/statespace/models/test_VARMAX.py +0 -190
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/.gitignore +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/.gitpod.yml +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/.pre-commit-config.yaml +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/.readthedocs.yaml +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/LICENSE +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/README.md +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/codecov.yml +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/.nojekyll +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/Makefile +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/_templates/autosummary/base.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/_templates/autosummary/class.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/api_reference.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/conf.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/index.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/make.bat +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/statespace/core.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/statespace/filters.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/statespace/models/structural.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/docs/statespace/models.rst +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/deserialize.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/__init__.py +5 -5
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/timeseries.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/transforms/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/distributions/transforms/partial_order.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/fit.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/laplace_approx/scipy_interface.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/marginal/distributions.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/marginal/graph_analysis.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/marginal/marginal_model.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/prior.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/core/compile.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/core/representation.py +8 -8
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/__init__.py +3 -3
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/distributions.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.3.1/pymc_extras/statespace/utils → pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components}/__init__.py +0 -0
- {pymc_extras-0.3.1/tests/inference → pymc_extras-0.4.1/pymc_extras/statespace/utils}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/utils/model_equivalence.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/conftest.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/distributions/test_transform.py +0 -0
- {pymc_extras-0.3.1/tests/inference/laplace_approx → pymc_extras-0.4.1/tests/inference}/__init__.py +0 -0
- {pymc_extras-0.3.1/tests/model → pymc_extras-0.4.1/tests/inference/laplace_approx}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_idata.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
- {pymc_extras-0.3.1/tests/model/marginal → pymc_extras-0.4.1/tests/model}/__init__.py +0 -0
- {pymc_extras-0.3.1/tests/statespace → pymc_extras-0.4.1/tests/model/marginal}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/model/marginal/test_distributions.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/model/marginal/test_graph_analysis.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/model/marginal/test_marginal_model.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/model/transforms/test_autoreparam.py +0 -0
- {pymc_extras-0.3.1/tests/statespace/core → pymc_extras-0.4.1/tests/statespace}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/_data/airpass.csv +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/_data/airpassangers.csv +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/_data/nile.csv +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
- {pymc_extras-0.3.1/tests/statespace/filters → pymc_extras-0.4.1/tests/statespace/core}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/core/test_representation.py +0 -0
- {pymc_extras-0.3.1/tests/statespace/models → pymc_extras-0.4.1/tests/statespace/filters}/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/filters/test_kalman_filter.py +0 -0
- {pymc_extras-0.3.1/tests/statespace/utils → pymc_extras-0.4.1/tests/statespace/models}/__init__.py +0 -0
- /pymc_extras-0.3.1/tests/utils.py → /pymc_extras-0.4.1/tests/statespace/models/structural/__init__.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/models/test_ETS.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/shared_fixtures.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/statespace/statsmodel_local_level.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_deserialize.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_pathfinder.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_printing.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_prior.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.3.1 → pymc_extras-0.4.1}/tests/test_splines.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Project-URL: Documentation, https://pymc-extras.readthedocs.io/
|
|
6
6
|
Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
|
|
@@ -232,9 +232,11 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
232
232
|
Classifier: Topic :: Scientific/Engineering
|
|
233
233
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
234
234
|
Requires-Python: >=3.11
|
|
235
|
-
Requires-Dist: better-optimize>=0.1.
|
|
235
|
+
Requires-Dist: better-optimize>=0.1.5
|
|
236
|
+
Requires-Dist: preliz>=0.20.0
|
|
236
237
|
Requires-Dist: pydantic>=2.0.0
|
|
237
|
-
Requires-Dist: pymc>=5.
|
|
238
|
+
Requires-Dist: pymc>=5.24.1
|
|
239
|
+
Requires-Dist: pytensor>=2.31.4
|
|
238
240
|
Requires-Dist: scikit-learn
|
|
239
241
|
Provides-Extra: complete
|
|
240
242
|
Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'
|
|
@@ -245,7 +247,6 @@ Requires-Dist: xhistogram; extra == 'dask-histogram'
|
|
|
245
247
|
Provides-Extra: dev
|
|
246
248
|
Requires-Dist: blackjax; extra == 'dev'
|
|
247
249
|
Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
|
|
248
|
-
Requires-Dist: preliz>=0.5.0; extra == 'dev'
|
|
249
250
|
Requires-Dist: pytest-mock; extra == 'dev'
|
|
250
251
|
Requires-Dist: pytest>=6.0; extra == 'dev'
|
|
251
252
|
Requires-Dist: statsmodels; extra == 'dev'
|
|
@@ -1,7 +1,14 @@
|
|
|
1
1
|
# file generated by setuptools-scm
|
|
2
2
|
# don't change, don't track in version control
|
|
3
3
|
|
|
4
|
-
__all__ = [
|
|
4
|
+
__all__ = [
|
|
5
|
+
"__version__",
|
|
6
|
+
"__version_tuple__",
|
|
7
|
+
"version",
|
|
8
|
+
"version_tuple",
|
|
9
|
+
"__commit_id__",
|
|
10
|
+
"commit_id",
|
|
11
|
+
]
|
|
5
12
|
|
|
6
13
|
TYPE_CHECKING = False
|
|
7
14
|
if TYPE_CHECKING:
|
|
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
|
|
|
9
16
|
from typing import Union
|
|
10
17
|
|
|
11
18
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
19
|
+
COMMIT_ID = Union[str, None]
|
|
12
20
|
else:
|
|
13
21
|
VERSION_TUPLE = object
|
|
22
|
+
COMMIT_ID = object
|
|
14
23
|
|
|
15
24
|
version: str
|
|
16
25
|
__version__: str
|
|
17
26
|
__version_tuple__: VERSION_TUPLE
|
|
18
27
|
version_tuple: VERSION_TUPLE
|
|
28
|
+
commit_id: COMMIT_ID
|
|
29
|
+
__commit_id__: COMMIT_ID
|
|
19
30
|
|
|
20
|
-
__version__ = version = '0.
|
|
21
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.4.1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 4, 1)
|
|
33
|
+
|
|
34
|
+
__commit_id__ = commit_id = None
|
|
@@ -3,9 +3,10 @@ channels:
|
|
|
3
3
|
- conda-forge
|
|
4
4
|
- nodefaults
|
|
5
5
|
dependencies:
|
|
6
|
-
- pymc>=5.
|
|
6
|
+
- pymc>=5.24.1
|
|
7
|
+
- pytensor>=2.31.4
|
|
7
8
|
- scikit-learn
|
|
8
|
-
- better-optimize>=0.1.
|
|
9
|
+
- better-optimize>=0.1.5
|
|
9
10
|
- dask<2025.1.1
|
|
10
11
|
- xhistogram
|
|
11
12
|
- statsmodels
|
|
@@ -13,7 +14,7 @@ dependencies:
|
|
|
13
14
|
- pytest
|
|
14
15
|
- pytest-cov
|
|
15
16
|
- pydantic>=2.0.0
|
|
16
|
-
- preliz>=0.
|
|
17
|
+
- preliz>=0.20.0
|
|
17
18
|
- pip
|
|
18
19
|
- pip:
|
|
19
20
|
- jax
|
|
@@ -18,7 +18,7 @@ import pymc as pm
|
|
|
18
18
|
|
|
19
19
|
from numpy.typing import ArrayLike
|
|
20
20
|
|
|
21
|
-
__all__ = ["
|
|
21
|
+
__all__ = ["discrete_histogram", "histogram_approximation", "quantile_histogram"]
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def quantile_histogram(
|
|
@@ -17,4 +17,4 @@ from pymc_extras.inference.laplace_approx.find_map import find_MAP
|
|
|
17
17
|
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
|
|
18
18
|
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
|
|
19
19
|
|
|
20
|
-
__all__ = ["
|
|
20
|
+
__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
|
|
@@ -326,7 +326,7 @@ def find_MAP(
|
|
|
326
326
|
)
|
|
327
327
|
|
|
328
328
|
raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
|
|
329
|
-
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
|
|
329
|
+
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
|
|
330
330
|
unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
|
|
331
331
|
DictToArrayBijection.rmap(raveled_optimized)
|
|
332
332
|
)
|
|
@@ -335,13 +335,20 @@ def find_MAP(
|
|
|
335
335
|
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
|
|
336
336
|
}
|
|
337
337
|
|
|
338
|
-
idata = map_results_to_inference_data(
|
|
339
|
-
|
|
338
|
+
idata = map_results_to_inference_data(
|
|
339
|
+
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
idata = add_fit_to_inference_data(
|
|
343
|
+
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
|
|
344
|
+
)
|
|
345
|
+
|
|
340
346
|
idata = add_optimizer_result_to_inference_data(
|
|
341
|
-
idata, optimizer_result, method, raveled_optimized, model
|
|
347
|
+
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
|
|
342
348
|
)
|
|
349
|
+
|
|
343
350
|
idata = add_data_to_inference_data(
|
|
344
|
-
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
|
|
351
|
+
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
|
|
345
352
|
)
|
|
346
353
|
|
|
347
354
|
return idata
|
|
@@ -59,6 +59,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
|
|
|
59
59
|
def map_results_to_inference_data(
|
|
60
60
|
map_point: dict[str, float | int | np.ndarray],
|
|
61
61
|
model: pm.Model | None = None,
|
|
62
|
+
include_transformed: bool = True,
|
|
62
63
|
):
|
|
63
64
|
"""
|
|
64
65
|
Add the MAP point to an InferenceData object in the posterior group.
|
|
@@ -68,13 +69,13 @@ def map_results_to_inference_data(
|
|
|
68
69
|
|
|
69
70
|
Parameters
|
|
70
71
|
----------
|
|
71
|
-
idata: az.InferenceData
|
|
72
|
-
An InferenceData object to which the MAP point will be added.
|
|
73
72
|
map_point: dict
|
|
74
73
|
A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
|
|
75
74
|
the values should be the corresponding MAP estimates.
|
|
76
75
|
model: Model, optional
|
|
77
76
|
A PyMC model. If None, the model is taken from the current model context.
|
|
77
|
+
include_transformed: bool
|
|
78
|
+
Whether to return transformed (unconstrained) variables in the constrained_posterior group. Default is True.
|
|
78
79
|
|
|
79
80
|
Returns
|
|
80
81
|
-------
|
|
@@ -118,7 +119,7 @@ def map_results_to_inference_data(
|
|
|
118
119
|
dims=dims,
|
|
119
120
|
)
|
|
120
121
|
|
|
121
|
-
if unconstrained_names:
|
|
122
|
+
if unconstrained_names and include_transformed:
|
|
122
123
|
unconstrained_posterior = az.from_dict(
|
|
123
124
|
posterior={
|
|
124
125
|
k: np.expand_dims(v, (0, 1))
|
|
@@ -302,7 +302,7 @@ def fit_laplace(
|
|
|
302
302
|
----------
|
|
303
303
|
model : pm.Model
|
|
304
304
|
The PyMC model to be fit. If None, the current model context is used.
|
|
305
|
-
|
|
305
|
+
optimize_method : str
|
|
306
306
|
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
|
|
307
307
|
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
|
|
308
308
|
|
|
@@ -441,9 +441,11 @@ def fit_laplace(
|
|
|
441
441
|
.rename({"temp_chain": "chain", "temp_draw": "draw"})
|
|
442
442
|
)
|
|
443
443
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
444
|
+
if include_transformed:
|
|
445
|
+
idata.unconstrained_posterior = unstack_laplace_draws(
|
|
446
|
+
new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
|
|
447
|
+
)
|
|
448
|
+
|
|
447
449
|
idata.posterior = new_posterior.drop_vars(
|
|
448
450
|
["laplace_approximation", "unpacked_variable_names"]
|
|
449
451
|
)
|
|
@@ -38,16 +38,15 @@ from pymc.blocking import DictToArrayBijection, RaveledVars
|
|
|
38
38
|
from pymc.initial_point import make_initial_point_fn
|
|
39
39
|
from pymc.model import modelcontext
|
|
40
40
|
from pymc.model.core import Point
|
|
41
|
+
from pymc.progress_bar import CustomProgress, default_progress_theme
|
|
41
42
|
from pymc.pytensorf import (
|
|
42
43
|
compile,
|
|
43
44
|
find_rng_nodes,
|
|
44
45
|
reseed_rngs,
|
|
45
46
|
)
|
|
46
47
|
from pymc.util import (
|
|
47
|
-
CustomProgress,
|
|
48
48
|
RandomSeed,
|
|
49
49
|
_get_seeds_per_chain,
|
|
50
|
-
default_progress_theme,
|
|
51
50
|
get_default_varnames,
|
|
52
51
|
)
|
|
53
52
|
from pytensor.compile.function.types import Function
|
|
@@ -166,7 +166,7 @@ def model_table(
|
|
|
166
166
|
|
|
167
167
|
for var in group:
|
|
168
168
|
var_name = var.name
|
|
169
|
-
sep = f
|
|
169
|
+
sep = f"[b]{' ~' if (var in model.basic_RVs) else ' ='}[/b]"
|
|
170
170
|
var_expr = variable_expression(model, var, truncate_deterministic)
|
|
171
171
|
dims_expr = dims_expression(model, var)
|
|
172
172
|
if dims_expr == "[]":
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from pymc_extras.statespace.core.compile import compile_statespace
|
|
2
2
|
from pymc_extras.statespace.models import structural
|
|
3
3
|
from pymc_extras.statespace.models.ETS import BayesianETS
|
|
4
|
-
from pymc_extras.statespace.models.SARIMAX import
|
|
4
|
+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
|
|
5
5
|
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
|
|
6
6
|
|
|
7
7
|
__all__ = [
|
|
8
|
-
"compile_statespace",
|
|
9
|
-
"structural",
|
|
10
8
|
"BayesianETS",
|
|
11
|
-
"
|
|
9
|
+
"BayesianSARIMAX",
|
|
12
10
|
"BayesianVARMAX",
|
|
11
|
+
"compile_statespace",
|
|
12
|
+
"structural",
|
|
13
13
|
]
|
|
@@ -4,4 +4,4 @@ from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
|
4
4
|
from pymc_extras.statespace.core.statespace import PyMCStateSpace
|
|
5
5
|
from pymc_extras.statespace.core.compile import compile_statespace
|
|
6
6
|
|
|
7
|
-
__all__ = ["
|
|
7
|
+
__all__ = ["PyMCStateSpace", "PytensorRepresentation", "compile_statespace"]
|
|
@@ -60,7 +60,7 @@ FILTER_FACTORY = {
|
|
|
60
60
|
def _validate_filter_arg(filter_arg):
|
|
61
61
|
if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
|
|
62
62
|
raise ValueError(
|
|
63
|
-
f
|
|
63
|
+
f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}"
|
|
64
64
|
)
|
|
65
65
|
|
|
66
66
|
|
|
@@ -233,10 +233,9 @@ class PyMCStateSpace:
|
|
|
233
233
|
self._fit_coords: dict[str, Sequence[str]] | None = None
|
|
234
234
|
self._fit_dims: dict[str, Sequence[str]] | None = None
|
|
235
235
|
self._fit_data: pt.TensorVariable | None = None
|
|
236
|
+
self._fit_exog_data: dict[str, dict] = {}
|
|
236
237
|
|
|
237
238
|
self._needs_exog_data = None
|
|
238
|
-
self._exog_names = []
|
|
239
|
-
self._exog_data_info = {}
|
|
240
239
|
self._name_to_variable = {}
|
|
241
240
|
self._name_to_data = {}
|
|
242
241
|
|
|
@@ -671,7 +670,7 @@ class PyMCStateSpace:
|
|
|
671
670
|
pymc_mod = modelcontext(None)
|
|
672
671
|
for data_name in self.data_names:
|
|
673
672
|
data = pymc_mod[data_name]
|
|
674
|
-
self.
|
|
673
|
+
self._fit_exog_data[data_name] = {
|
|
675
674
|
"name": data_name,
|
|
676
675
|
"value": data.get_value(),
|
|
677
676
|
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
|
|
@@ -685,7 +684,7 @@ class PyMCStateSpace:
|
|
|
685
684
|
--------
|
|
686
685
|
.. code:: python
|
|
687
686
|
|
|
688
|
-
ss_mod = pmss.
|
|
687
|
+
ss_mod = pmss.BayesianSARIMAX(order=(2, 0, 2), verbose=False, stationary_initialization=True)
|
|
689
688
|
with pm.Model():
|
|
690
689
|
x0 = pm.Normal('x0', size=ss_mod.k_states)
|
|
691
690
|
ar_params = pm.Normal('ar_params', size=ss_mod.p)
|
|
@@ -805,16 +804,16 @@ class PyMCStateSpace:
|
|
|
805
804
|
states, covs = outputs[:4], outputs[4:]
|
|
806
805
|
|
|
807
806
|
state_names = [
|
|
808
|
-
"
|
|
809
|
-
"
|
|
810
|
-
"
|
|
811
|
-
"
|
|
807
|
+
"filtered_states",
|
|
808
|
+
"predicted_states",
|
|
809
|
+
"predicted_observed_states",
|
|
810
|
+
"smoothed_states",
|
|
812
811
|
]
|
|
813
812
|
cov_names = [
|
|
814
|
-
"
|
|
815
|
-
"
|
|
816
|
-
"
|
|
817
|
-
"
|
|
813
|
+
"filtered_covariances",
|
|
814
|
+
"predicted_covariances",
|
|
815
|
+
"predicted_observed_covariances",
|
|
816
|
+
"smoothed_covariances",
|
|
818
817
|
]
|
|
819
818
|
|
|
820
819
|
with mod:
|
|
@@ -939,7 +938,7 @@ class PyMCStateSpace:
|
|
|
939
938
|
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
|
|
940
939
|
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
|
|
941
940
|
|
|
942
|
-
obs_dims = FILTER_OUTPUT_DIMS["
|
|
941
|
+
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"]
|
|
943
942
|
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
|
|
944
943
|
|
|
945
944
|
SequenceMvNormal(
|
|
@@ -1082,7 +1081,7 @@ class PyMCStateSpace:
|
|
|
1082
1081
|
|
|
1083
1082
|
for name in self.data_names:
|
|
1084
1083
|
if name not in pm_mod:
|
|
1085
|
-
pm.Data(**self.
|
|
1084
|
+
pm.Data(**self._fit_exog_data[name])
|
|
1086
1085
|
|
|
1087
1086
|
self._insert_data_variables()
|
|
1088
1087
|
|
|
@@ -1229,7 +1228,7 @@ class PyMCStateSpace:
|
|
|
1229
1228
|
method=mvn_method,
|
|
1230
1229
|
)
|
|
1231
1230
|
|
|
1232
|
-
obs_mu = (Z @ mu[..., None]).squeeze(-1)
|
|
1231
|
+
obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
|
|
1233
1232
|
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H
|
|
1234
1233
|
|
|
1235
1234
|
SequenceMvNormal(
|
|
@@ -1351,7 +1350,7 @@ class PyMCStateSpace:
|
|
|
1351
1350
|
self._insert_random_variables()
|
|
1352
1351
|
|
|
1353
1352
|
for name in self.data_names:
|
|
1354
|
-
pm.Data(**self.
|
|
1353
|
+
pm.Data(**self._fit_exog_data[name])
|
|
1355
1354
|
|
|
1356
1355
|
self._insert_data_variables()
|
|
1357
1356
|
|
|
@@ -1651,7 +1650,7 @@ class PyMCStateSpace:
|
|
|
1651
1650
|
self._insert_random_variables()
|
|
1652
1651
|
|
|
1653
1652
|
for name in self.data_names:
|
|
1654
|
-
pm.Data(**self.
|
|
1653
|
+
pm.Data(**self.data_info[name])
|
|
1655
1654
|
|
|
1656
1655
|
self._insert_data_variables()
|
|
1657
1656
|
matrices = self.unpack_statespace()
|
|
@@ -1678,6 +1677,78 @@ class PyMCStateSpace:
|
|
|
1678
1677
|
|
|
1679
1678
|
return matrix_idata
|
|
1680
1679
|
|
|
1680
|
+
def sample_filter_outputs(
|
|
1681
|
+
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
|
|
1682
|
+
):
|
|
1683
|
+
if isinstance(filter_output_names, str):
|
|
1684
|
+
filter_output_names = [filter_output_names]
|
|
1685
|
+
|
|
1686
|
+
if filter_output_names is None:
|
|
1687
|
+
filter_output_names = list(FILTER_OUTPUT_DIMS.keys())
|
|
1688
|
+
else:
|
|
1689
|
+
unknown_filter_output_names = np.setdiff1d(
|
|
1690
|
+
filter_output_names, list(FILTER_OUTPUT_DIMS.keys())
|
|
1691
|
+
)
|
|
1692
|
+
if unknown_filter_output_names.size > 0:
|
|
1693
|
+
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
|
|
1694
|
+
filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names]
|
|
1695
|
+
|
|
1696
|
+
compile_kwargs = kwargs.pop("compile_kwargs", {})
|
|
1697
|
+
compile_kwargs.setdefault("mode", self.mode)
|
|
1698
|
+
|
|
1699
|
+
with pm.Model(coords=self.coords) as m:
|
|
1700
|
+
self._build_dummy_graph()
|
|
1701
|
+
self._insert_random_variables()
|
|
1702
|
+
|
|
1703
|
+
if self.data_names:
|
|
1704
|
+
for name in self.data_names:
|
|
1705
|
+
pm.Data(**self._fit_exog_data[name])
|
|
1706
|
+
|
|
1707
|
+
self._insert_data_variables()
|
|
1708
|
+
|
|
1709
|
+
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
|
|
1710
|
+
data = self._fit_data
|
|
1711
|
+
|
|
1712
|
+
obs_coords = m.coords.get(OBS_STATE_DIM, None)
|
|
1713
|
+
|
|
1714
|
+
data, nan_mask = register_data_with_pymc(
|
|
1715
|
+
data,
|
|
1716
|
+
n_obs=self.ssm.k_endog,
|
|
1717
|
+
obs_coords=obs_coords,
|
|
1718
|
+
register_data=True,
|
|
1719
|
+
)
|
|
1720
|
+
|
|
1721
|
+
filter_outputs = self.kalman_filter.build_graph(
|
|
1722
|
+
data,
|
|
1723
|
+
x0,
|
|
1724
|
+
P0,
|
|
1725
|
+
c,
|
|
1726
|
+
d,
|
|
1727
|
+
T,
|
|
1728
|
+
Z,
|
|
1729
|
+
R,
|
|
1730
|
+
H,
|
|
1731
|
+
Q,
|
|
1732
|
+
)
|
|
1733
|
+
|
|
1734
|
+
smoother_outputs = self.kalman_smoother.build_graph(
|
|
1735
|
+
T, R, Q, filter_outputs[0], filter_outputs[3]
|
|
1736
|
+
)
|
|
1737
|
+
|
|
1738
|
+
filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
|
|
1739
|
+
for output in filter_outputs:
|
|
1740
|
+
if output.name in filter_output_names:
|
|
1741
|
+
dims = FILTER_OUTPUT_DIMS[output.name]
|
|
1742
|
+
pm.Deterministic(output.name, output, dims=dims)
|
|
1743
|
+
|
|
1744
|
+
with freeze_dims_and_data(m):
|
|
1745
|
+
return pm.sample_posterior_predictive(
|
|
1746
|
+
idata if group == "posterior" else idata.prior,
|
|
1747
|
+
var_names=filter_output_names,
|
|
1748
|
+
compile_kwargs=compile_kwargs,
|
|
1749
|
+
**kwargs,
|
|
1750
|
+
)
|
|
1751
|
+
|
|
1681
1752
|
@staticmethod
|
|
1682
1753
|
def _validate_forecast_args(
|
|
1683
1754
|
time_index: pd.RangeIndex | pd.DatetimeIndex,
|
|
@@ -1774,7 +1845,7 @@ class PyMCStateSpace:
|
|
|
1774
1845
|
}
|
|
1775
1846
|
|
|
1776
1847
|
if self._needs_exog_data and scenario is None:
|
|
1777
|
-
exog_str = ",".join(self.
|
|
1848
|
+
exog_str = ",".join(self.data_names)
|
|
1778
1849
|
suffix = "s" if len(exog_str) > 1 else ""
|
|
1779
1850
|
raise ValueError(
|
|
1780
1851
|
f"This model was fit using exogenous data. Forecasting cannot be performed without "
|
|
@@ -1783,7 +1854,7 @@ class PyMCStateSpace:
|
|
|
1783
1854
|
|
|
1784
1855
|
if isinstance(scenario, dict):
|
|
1785
1856
|
for name, data in scenario.items():
|
|
1786
|
-
if name not in self.
|
|
1857
|
+
if name not in self.data_names:
|
|
1787
1858
|
raise ValueError(
|
|
1788
1859
|
f"Scenario data provided for variable '{name}', which is not an exogenous variable "
|
|
1789
1860
|
f"used to fit the model."
|
|
@@ -1824,12 +1895,12 @@ class PyMCStateSpace:
|
|
|
1824
1895
|
# name should only be None on the first non-recursive call. We only arrive to this branch in that case
|
|
1825
1896
|
# if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
|
|
1826
1897
|
# needs to be set.
|
|
1827
|
-
if len(self.
|
|
1898
|
+
if len(self.data_names) > 1:
|
|
1828
1899
|
raise ValueError(
|
|
1829
1900
|
"Multiple exogenous variables were used to fit the model. Provide a dictionary of "
|
|
1830
1901
|
"scenario data instead."
|
|
1831
1902
|
)
|
|
1832
|
-
name = self.
|
|
1903
|
+
name = self.data_names[0]
|
|
1833
1904
|
|
|
1834
1905
|
# Omit dataframe from this basic shape check so we can give more detailed information about missing columns
|
|
1835
1906
|
# in the next check
|
|
@@ -2031,7 +2102,7 @@ class PyMCStateSpace:
|
|
|
2031
2102
|
return scenario
|
|
2032
2103
|
|
|
2033
2104
|
# This was already checked as valid
|
|
2034
|
-
name = self.
|
|
2105
|
+
name = self.data_names[0] if name is None else name
|
|
2035
2106
|
|
|
2036
2107
|
# Small tidying up in the case we just have a single scenario that's already a dataframe.
|
|
2037
2108
|
if isinstance(scenario, pd.DataFrame | pd.Series):
|
|
@@ -15,10 +15,15 @@ from pymc_extras.statespace.filters.utilities import (
|
|
|
15
15
|
split_vars_into_seq_and_nonseq,
|
|
16
16
|
stabilize,
|
|
17
17
|
)
|
|
18
|
-
from pymc_extras.statespace.utils.constants import
|
|
18
|
+
from pymc_extras.statespace.utils.constants import (
|
|
19
|
+
FILTER_OUTPUT_NAMES,
|
|
20
|
+
JITTER_DEFAULT,
|
|
21
|
+
MATRIX_NAMES,
|
|
22
|
+
MISSING_FILL,
|
|
23
|
+
)
|
|
19
24
|
|
|
20
25
|
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
|
|
21
|
-
PARAM_NAMES = [
|
|
26
|
+
PARAM_NAMES = MATRIX_NAMES[2:]
|
|
22
27
|
|
|
23
28
|
assert_time_varying_dim_correct = Assert(
|
|
24
29
|
"The first dimension of a time varying matrix (the time dimension) must be "
|
|
@@ -119,7 +124,7 @@ class BaseFilter(ABC):
|
|
|
119
124
|
# There are always two outputs_info wedged between the seqs and non_seqs
|
|
120
125
|
seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
|
|
121
126
|
return_ordered = []
|
|
122
|
-
for name in
|
|
127
|
+
for name in PARAM_NAMES:
|
|
123
128
|
if name in self.seq_names:
|
|
124
129
|
idx = self.seq_names.index(name)
|
|
125
130
|
return_ordered.append(seqs[idx])
|
|
@@ -253,28 +258,28 @@ class BaseFilter(ABC):
|
|
|
253
258
|
)
|
|
254
259
|
|
|
255
260
|
filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
|
|
256
|
-
filtered_states.name =
|
|
261
|
+
filtered_states.name = FILTER_OUTPUT_NAMES[0]
|
|
257
262
|
|
|
258
263
|
predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
|
|
259
|
-
predicted_states.name =
|
|
260
|
-
|
|
261
|
-
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
|
|
262
|
-
observed_states.name = "observed_states"
|
|
264
|
+
predicted_states.name = FILTER_OUTPUT_NAMES[1]
|
|
263
265
|
|
|
264
266
|
filtered_covariances = pt.specify_shape(
|
|
265
267
|
filtered_covariances, (n, self.n_states, self.n_states)
|
|
266
268
|
)
|
|
267
|
-
filtered_covariances.name =
|
|
269
|
+
filtered_covariances.name = FILTER_OUTPUT_NAMES[2]
|
|
268
270
|
|
|
269
271
|
predicted_covariances = pt.specify_shape(
|
|
270
272
|
predicted_covariances, (n, self.n_states, self.n_states)
|
|
271
273
|
)
|
|
272
|
-
predicted_covariances.name =
|
|
274
|
+
predicted_covariances.name = FILTER_OUTPUT_NAMES[3]
|
|
275
|
+
|
|
276
|
+
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
|
|
277
|
+
observed_states.name = FILTER_OUTPUT_NAMES[4]
|
|
273
278
|
|
|
274
279
|
observed_covariances = pt.specify_shape(
|
|
275
280
|
observed_covariances, (n, self.n_endog, self.n_endog)
|
|
276
281
|
)
|
|
277
|
-
observed_covariances.name =
|
|
282
|
+
observed_covariances.name = FILTER_OUTPUT_NAMES[5]
|
|
278
283
|
|
|
279
284
|
loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
|
|
280
285
|
loglike_obs.name = "loglike_obs"
|