pymc-extras: pymc_extras-0.2.4-py3-none-any.whl → pymc_extras-0.2.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,156 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pymc as pm
|
|
5
|
-
import pytensor
|
|
6
|
-
import pytensor.tensor as pt
|
|
7
|
-
import pytest
|
|
8
|
-
|
|
9
|
-
from pymc.model.transform.optimization import freeze_dims_and_data
|
|
10
|
-
|
|
11
|
-
from pymc_extras.statespace.utils.constants import (
|
|
12
|
-
FILTER_OUTPUT_NAMES,
|
|
13
|
-
MATRIX_NAMES,
|
|
14
|
-
SMOOTHER_OUTPUT_NAMES,
|
|
15
|
-
)
|
|
16
|
-
from tests.statespace.test_statespace import ( # pylint: disable=unused-import
|
|
17
|
-
exog_ss_mod,
|
|
18
|
-
ss_mod,
|
|
19
|
-
)
|
|
20
|
-
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
|
|
21
|
-
rng,
|
|
22
|
-
)
|
|
23
|
-
from tests.statespace.utilities.test_helpers import load_nile_test_data
|
|
24
|
-
|
|
25
|
-
# Skip this whole module unless the optional JAX sampling stack is installed.
for _optional_dependency in ("jax", "numpyro"):
    pytest.importorskip(_optional_dependency)
del _optional_dependency

floatX = pytensor.config.floatX

# Nile test dataset, loaded once via the shared test helper and used as observed data.
nile = load_nile_test_data()

# Every matrix / filter / smoother output name checked for NaNs in the sampled groups.
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
    """Session-scoped PyMC model wrapping ``ss_mod`` with the JAX backend.

    Declares ``rho ~ Beta(1, 1)`` and the deterministic ``zeta = 1 - rho``.
    It then builds the statespace graph on the Nile data with Kalman filter
    outputs saved into the InferenceData. Finally it registers each unpacked
    statespace matrix as a ``Deterministic`` so the matrices show up in the
    sampling output.
    """
    matrix_names = ("x0", "P0", "c", "d", "T", "Z", "R", "H", "Q")

    with pm.Model(coords=ss_mod.coords) as model:
        rho = pm.Beta("rho", 1, 1)
        pm.Deterministic("zeta", 1 - rho)

        ss_mod.build_statespace_graph(
            data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True
        )

        statespace_matrices = ss_mod.unpack_statespace()
        for matrix_name, matrix in zip(matrix_names, statespace_matrices):
            pm.Deterministic(matrix_name, matrix)

    return model
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
    """Session-scoped PyMC model wrapping ``exog_ss_mod`` with the JAX backend.

    Observed data (100 x 1) and exogenous regressors (100 x 3) are random
    normal draws from ``rng``, in that order, cast to ``floatX``.
    """
    n_obs = 100
    observed = rng.normal(size=(n_obs, 1)).astype(floatX)
    exog_values = rng.normal(size=(n_obs, 3)).astype(floatX)

    with pm.Model(coords=exog_ss_mod.coords) as model:
        # NOTE(review): these variable names presumably have to match the names
        # exog_ss_mod looks up when building its graph; do not rename them.
        pm.Data("data_exog", exog_values)
        pm.Normal("initial_trend", dims=["trend_state"])
        p0_scale = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0", pt.eye(exog_ss_mod.k_states) * p0_scale, dims=["state", "state_aux"]
        )
        pm.Normal("beta_exog", dims=["exog_state"])
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        exog_ss_mod.build_statespace_graph(observed, mode="JAX")

    return model
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
    """Posterior and prior draws for ``pymc_mod``, merged into one InferenceData.

    The posterior comes from a very short single-chain NUTS run on the numpyro
    backend. The prior comes from ``sample_prior_predictive`` compiled in JAX
    mode on a dims/data-frozen copy of the model. All warnings raised while
    sampling are suppressed.
    """
    nuts_settings = {
        "draws": 10,
        "tune": 1,
        "chains": 1,
        "random_seed": rng,
        "nuts_sampler": "numpyro",
        "progressbar": False,
    }

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with pymc_mod:
            combined = pm.sample(**nuts_settings)
        with freeze_dims_and_data(pymc_mod):
            prior_draws = pm.sample_prior_predictive(
                samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
            )

    combined.extend(prior_draws)
    return combined
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
    """Posterior and prior draws for ``exog_pymc_mod``, merged into one InferenceData.

    The posterior comes from a very short single-chain NUTS run on the numpyro
    backend. The prior comes from ``sample_prior_predictive`` compiled in JAX
    mode on a dims/data-frozen copy of the same model. All warnings raised
    while sampling are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        with exog_pymc_mod:
            idata = pm.sample(
                draws=10,
                tune=1,
                chains=1,
                random_seed=rng,
                nuts_sampler="numpyro",
                progressbar=False,
            )

        # Freeze this fixture's own model. ``pymc_mod`` is not a parameter here:
        # that bare name would resolve to the module-level fixture function,
        # not to a PyMC model.
        with freeze_dims_and_data(exog_pymc_mod):
            idata_prior = pm.sample_prior_predictive(
                samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
            )

    idata.extend(idata_prior)
    return idata
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
    """Each saved matrix/filter/smoother output in ``idata[group]`` is NaN-free."""
    sampled_values = idata[group][matrix].values
    assert not np.isnan(sampled_values).any()
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
@pytest.mark.parametrize("group", ["prior", "posterior"])
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
def test_sampling_methods(group, kind, ss_mod, idata, rng):
    """``sample_{kind}_{group}`` runs in JAX fit mode and returns NaN-free groups.

    Conditional sampling must yield filtered/predicted/smoothed groups, each
    with an ``_observed`` companion. Unconditional sampling must yield latent
    and observed groups.
    """
    assert ss_mod._fit_mode == "JAX"

    sampler = getattr(ss_mod, f"sample_{kind}_{group}")
    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        sampled = sampler(idata, random_seed=rng)

    if kind == "conditional":
        for output in ("filtered", "predicted", "smoothed"):
            group_name = f"{output}_{group}"
            assert group_name in sampled
            assert not np.isnan(sampled[group_name].values).any()
            assert not np.isnan(sampled[f"{group_name}_observed"].values).any()
    elif kind == "unconditional":
        for output in ("latent", "observed"):
            group_name = f"{group}_{output}"
            assert group_name in sampled
            assert not np.isnan(sampled[group_name].values).any()
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
def test_forecast(filter_output, ss_mod, idata, rng):
    """A 10-step forecast from the last posterior time point is well-shaped and NaN-free."""
    n_periods = 10
    last_time_point = idata.posterior.coords["time"].values[-1]

    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        forecast_idata = ss_mod.forecast(
            idata,
            start=last_time_point,
            periods=n_periods,
            filter_output=filter_output,
            random_seed=rng,
        )

    assert forecast_idata.coords["time"].values.shape == (n_periods,)
    assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
    assert forecast_idata.forecast_observed.dims == (
        "chain",
        "draw",
        "time",
        "observed_state",
    )

    assert not np.isnan(forecast_idata.forecast_latent.values).any()
    assert not np.isnan(forecast_idata.forecast_observed.values).any()
|