pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/deserialize.py +224 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/prior.py +1356 -0
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.7.dist-info/METADATA +321 -0
- pymc_extras-0.2.7.dist-info/RECORD +66 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
- pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
tests/model/test_model_api.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pymc as pm
|
|
3
|
-
|
|
4
|
-
import pymc_extras as pmx
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def test_logp():
|
|
8
|
-
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
|
|
9
|
-
and a functional syntax. Checks whether the kwarg `coords` can be passed.
|
|
10
|
-
"""
|
|
11
|
-
coords = {"obs": ["a", "b"]}
|
|
12
|
-
|
|
13
|
-
with pm.Model(coords=coords) as model:
|
|
14
|
-
pm.Normal("x", 0.0, 1.0, dims="obs")
|
|
15
|
-
|
|
16
|
-
@pmx.as_model(coords=coords)
|
|
17
|
-
def model_wrapped():
|
|
18
|
-
pm.Normal("x", 0.0, 1.0, dims="obs")
|
|
19
|
-
|
|
20
|
-
mw = model_wrapped()
|
|
21
|
-
|
|
22
|
-
@pmx.as_model()
|
|
23
|
-
def model_wrapped2():
|
|
24
|
-
pm.Normal("x", 0.0, 1.0, dims="obs")
|
|
25
|
-
|
|
26
|
-
mw2 = model_wrapped2(coords=coords)
|
|
27
|
-
|
|
28
|
-
@pmx.as_model()
|
|
29
|
-
def model_wrapped3(mu):
|
|
30
|
-
pm.Normal("x", mu, 1.0, dims="obs")
|
|
31
|
-
|
|
32
|
-
mw3 = model_wrapped3(0.0, coords=coords)
|
|
33
|
-
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
|
|
34
|
-
|
|
35
|
-
np.testing.assert_equal(model.point_logps(), mw.point_logps())
|
|
36
|
-
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
|
|
37
|
-
assert mw3["mu"] in mw3.data_vars
|
|
38
|
-
assert "mu" not in mw4
|
tests/statespace/__init__.py
DELETED
|
File without changes
|
tests/statespace/test_ETS.py
DELETED
|
@@ -1,411 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pytensor
|
|
3
|
-
import pytest
|
|
4
|
-
import statsmodels.api as sm
|
|
5
|
-
|
|
6
|
-
from numpy.testing import assert_allclose
|
|
7
|
-
from pytensor.graph.basic import explicit_graph_inputs
|
|
8
|
-
from scipy import linalg
|
|
9
|
-
|
|
10
|
-
from pymc_extras.statespace.models.ETS import BayesianETS
|
|
11
|
-
from pymc_extras.statespace.utils.constants import LONG_MATRIX_NAMES
|
|
12
|
-
from tests.statespace.utilities.shared_fixtures import rng
|
|
13
|
-
from tests.statespace.utilities.test_helpers import load_nile_test_data
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@pytest.fixture(scope="session")
|
|
17
|
-
def data():
|
|
18
|
-
return load_nile_test_data()
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def tests_invalid_order_raises():
|
|
22
|
-
# Order must be length 3
|
|
23
|
-
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
|
|
24
|
-
BayesianETS(order=("A", "N"))
|
|
25
|
-
|
|
26
|
-
# Order must be strings
|
|
27
|
-
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
|
|
28
|
-
BayesianETS(order=(2, 1, 1))
|
|
29
|
-
|
|
30
|
-
# Only additive errors allowed
|
|
31
|
-
with pytest.raises(ValueError, match="Only additive errors are supported"):
|
|
32
|
-
BayesianETS(order=("M", "N", "N"))
|
|
33
|
-
|
|
34
|
-
# Trend must be A or Ad
|
|
35
|
-
with pytest.raises(ValueError, match="Invalid trend specification"):
|
|
36
|
-
BayesianETS(order=("A", "P", "N"))
|
|
37
|
-
|
|
38
|
-
# Seasonal must be A or N
|
|
39
|
-
with pytest.raises(ValueError, match="Invalid seasonal specification"):
|
|
40
|
-
BayesianETS(order=("A", "Ad", "M"))
|
|
41
|
-
|
|
42
|
-
# seasonal_periods must be provided if seasonal is requested
|
|
43
|
-
with pytest.raises(ValueError, match="If seasonal is True, seasonal_periods must be provided."):
|
|
44
|
-
BayesianETS(order=("A", "Ad", "A"))
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
orders = (
|
|
48
|
-
("A", "N", "N"),
|
|
49
|
-
("A", "A", "N"),
|
|
50
|
-
("A", "Ad", "N"),
|
|
51
|
-
("A", "N", "A"),
|
|
52
|
-
("A", "A", "A"),
|
|
53
|
-
("A", "Ad", "A"),
|
|
54
|
-
)
|
|
55
|
-
order_names = (
|
|
56
|
-
"Basic",
|
|
57
|
-
"Trend",
|
|
58
|
-
"Damped Trend",
|
|
59
|
-
"Seasonal",
|
|
60
|
-
"Trend and Seasonal",
|
|
61
|
-
"Trend, Damped Trend, Seasonal",
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
order_expected_flags = (
|
|
65
|
-
{"trend": False, "damped_trend": False, "seasonal": False},
|
|
66
|
-
{"trend": True, "damped_trend": False, "seasonal": False},
|
|
67
|
-
{"trend": True, "damped_trend": True, "seasonal": False},
|
|
68
|
-
{"trend": False, "damped_trend": False, "seasonal": True},
|
|
69
|
-
{"trend": True, "damped_trend": False, "seasonal": True},
|
|
70
|
-
{"trend": True, "damped_trend": True, "seasonal": True},
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
order_params = (
|
|
74
|
-
["alpha", "initial_level"],
|
|
75
|
-
["alpha", "initial_level", "beta", "initial_trend"],
|
|
76
|
-
["alpha", "initial_level", "beta", "initial_trend", "phi"],
|
|
77
|
-
["alpha", "initial_level", "gamma", "initial_seasonal"],
|
|
78
|
-
["alpha", "initial_level", "beta", "initial_trend", "gamma", "initial_seasonal"],
|
|
79
|
-
["alpha", "initial_level", "beta", "initial_trend", "gamma", "initial_seasonal", "phi"],
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
@pytest.mark.parametrize(
|
|
84
|
-
"order, expected_flags", zip(orders, order_expected_flags), ids=order_names
|
|
85
|
-
)
|
|
86
|
-
def test_order_flags(order, expected_flags):
|
|
87
|
-
mod = BayesianETS(order=order, seasonal_periods=4)
|
|
88
|
-
for key, value in expected_flags.items():
|
|
89
|
-
assert getattr(mod, key) == value
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
@pytest.mark.parametrize("order, expected_params", zip(orders, order_params), ids=order_names)
|
|
93
|
-
def test_param_info(order: tuple[str, str, str], expected_params):
|
|
94
|
-
mod = BayesianETS(order=order, seasonal_periods=4)
|
|
95
|
-
|
|
96
|
-
all_expected_params = [*expected_params, "sigma_state", "P0"]
|
|
97
|
-
assert all(param in mod.param_names for param in all_expected_params)
|
|
98
|
-
assert all(param in all_expected_params for param in mod.param_names)
|
|
99
|
-
assert all(
|
|
100
|
-
mod.param_info[param]["dims"] is None
|
|
101
|
-
for param in expected_params
|
|
102
|
-
if "seasonal" not in param
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
@pytest.mark.parametrize("order, expected_params", zip(orders, order_params), ids=order_names)
|
|
107
|
-
@pytest.mark.parametrize("use_transformed", [True, False], ids=["transformed", "untransformed"])
|
|
108
|
-
def test_statespace_matrices(
|
|
109
|
-
rng, order: tuple[str, str, str], expected_params: list[str], use_transformed: bool
|
|
110
|
-
):
|
|
111
|
-
seasonal_periods = np.random.randint(3, 12)
|
|
112
|
-
mod = BayesianETS(
|
|
113
|
-
order=order,
|
|
114
|
-
seasonal_periods=seasonal_periods,
|
|
115
|
-
measurement_error=True,
|
|
116
|
-
use_transformed_parameterization=use_transformed,
|
|
117
|
-
)
|
|
118
|
-
expected_states = 2 + int(order[1] != "N") + int(order[2] != "N") * seasonal_periods
|
|
119
|
-
|
|
120
|
-
test_values = {
|
|
121
|
-
"alpha": rng.beta(1, 1),
|
|
122
|
-
"beta": rng.beta(1, 1),
|
|
123
|
-
"gamma": rng.beta(1, 1),
|
|
124
|
-
"phi": rng.beta(1, 1),
|
|
125
|
-
"sigma_state": rng.normal() ** 2,
|
|
126
|
-
"sigma_obs": rng.normal() ** 2,
|
|
127
|
-
"initial_level": rng.normal() ** 2,
|
|
128
|
-
"initial_trend": rng.normal() ** 2,
|
|
129
|
-
"initial_seasonal": np.ones(seasonal_periods),
|
|
130
|
-
"initial_state_cov": np.eye(expected_states),
|
|
131
|
-
}
|
|
132
|
-
|
|
133
|
-
matrices = x0, P0, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()
|
|
134
|
-
|
|
135
|
-
assert x0.type.shape == (expected_states,)
|
|
136
|
-
assert P0.type.shape == (expected_states, expected_states)
|
|
137
|
-
assert c.type.shape == (expected_states,)
|
|
138
|
-
assert d.type.shape == (1,)
|
|
139
|
-
assert T.type.shape == (expected_states, expected_states)
|
|
140
|
-
assert Z.type.shape == (1, expected_states)
|
|
141
|
-
assert R.type.shape == (expected_states, 1)
|
|
142
|
-
assert H.type.shape == (1, 1)
|
|
143
|
-
assert Q.type.shape == (1, 1)
|
|
144
|
-
|
|
145
|
-
inputs = list(explicit_graph_inputs(matrices))
|
|
146
|
-
input_names = [x.name for x in inputs]
|
|
147
|
-
assert all(name in input_names for name in expected_params)
|
|
148
|
-
|
|
149
|
-
f_matrices = pytensor.function(inputs, matrices)
|
|
150
|
-
[x0, P0, c, d, T, Z, R, H, Q] = f_matrices(**{name: test_values[name] for name in input_names})
|
|
151
|
-
|
|
152
|
-
assert_allclose(H, np.eye(1) * test_values["sigma_obs"] ** 2)
|
|
153
|
-
assert_allclose(Q, np.eye(1) * test_values["sigma_state"] ** 2)
|
|
154
|
-
|
|
155
|
-
R_val = np.zeros((expected_states, 1))
|
|
156
|
-
R_val[0] = 1.0 - test_values["alpha"]
|
|
157
|
-
R_val[1] = test_values["alpha"]
|
|
158
|
-
|
|
159
|
-
Z_val = np.zeros((1, expected_states))
|
|
160
|
-
Z_val[0, 0] = 1.0
|
|
161
|
-
Z_val[0, 1] = 1.0
|
|
162
|
-
|
|
163
|
-
x0_val = np.zeros((expected_states,))
|
|
164
|
-
x0_val[1] = test_values["initial_level"]
|
|
165
|
-
|
|
166
|
-
if order[1] == "N":
|
|
167
|
-
T_val = np.array([[0.0, 0.0], [0.0, 1.0]])
|
|
168
|
-
else:
|
|
169
|
-
x0_val[2] = test_values["initial_trend"]
|
|
170
|
-
R_val[2] = (
|
|
171
|
-
test_values["beta"] if use_transformed else test_values["beta"] * test_values["alpha"]
|
|
172
|
-
)
|
|
173
|
-
T_val = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
|
|
174
|
-
|
|
175
|
-
if order[1] == "Ad":
|
|
176
|
-
T_val[1:, -1] *= test_values["phi"]
|
|
177
|
-
|
|
178
|
-
if order[2] == "A":
|
|
179
|
-
x0_val[2 + int(order[1] != "N") :] = test_values["initial_seasonal"]
|
|
180
|
-
gamma = (
|
|
181
|
-
test_values["gamma"]
|
|
182
|
-
if use_transformed
|
|
183
|
-
else (1 - test_values["alpha"]) * test_values["gamma"]
|
|
184
|
-
)
|
|
185
|
-
R_val[2 + int(order[1] != "N")] = gamma
|
|
186
|
-
R_val[0] = R_val[0] - gamma
|
|
187
|
-
|
|
188
|
-
S = np.eye(seasonal_periods, k=-1)
|
|
189
|
-
S[0, -1] = 1.0
|
|
190
|
-
Z_val[0, 2 + int(order[1] != "N")] = 1.0
|
|
191
|
-
else:
|
|
192
|
-
S = np.eye(0)
|
|
193
|
-
|
|
194
|
-
T_val = linalg.block_diag(T_val, S)
|
|
195
|
-
|
|
196
|
-
assert_allclose(x0, x0_val)
|
|
197
|
-
assert_allclose(T, T_val)
|
|
198
|
-
assert_allclose(R, R_val)
|
|
199
|
-
assert_allclose(Z, Z_val)
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
@pytest.mark.parametrize("order, params", zip(orders, order_params), ids=order_names)
|
|
203
|
-
def test_statespace_matches_statsmodels(rng, order: tuple[str, str, str], params):
|
|
204
|
-
seasonal_periods = rng.integers(3, 12)
|
|
205
|
-
data = rng.normal(size=(100,))
|
|
206
|
-
mod = BayesianETS(
|
|
207
|
-
order=order,
|
|
208
|
-
seasonal_periods=seasonal_periods,
|
|
209
|
-
measurement_error=False,
|
|
210
|
-
use_transformed_parameterization=True,
|
|
211
|
-
)
|
|
212
|
-
sm_mod = sm.tsa.statespace.ExponentialSmoothing(
|
|
213
|
-
data,
|
|
214
|
-
trend=mod.trend,
|
|
215
|
-
damped_trend=mod.damped_trend,
|
|
216
|
-
seasonal=seasonal_periods if mod.seasonal else None,
|
|
217
|
-
)
|
|
218
|
-
|
|
219
|
-
simplex_params = ["alpha", "beta", "gamma"]
|
|
220
|
-
test_values = dict(zip(simplex_params, rng.dirichlet(alpha=np.ones(3))))
|
|
221
|
-
test_values["phi"] = rng.beta(1, 1)
|
|
222
|
-
|
|
223
|
-
test_values["initial_level"] = rng.normal()
|
|
224
|
-
test_values["initial_trend"] = rng.normal()
|
|
225
|
-
test_values["initial_seasonal"] = rng.normal(size=seasonal_periods)
|
|
226
|
-
test_values["initial_state_cov"] = np.eye(mod.k_states)
|
|
227
|
-
test_values["sigma_state"] = 1.0
|
|
228
|
-
|
|
229
|
-
sm_test_values = test_values.copy()
|
|
230
|
-
sm_test_values["smoothing_level"] = test_values["alpha"]
|
|
231
|
-
sm_test_values["smoothing_trend"] = test_values["beta"]
|
|
232
|
-
sm_test_values["smoothing_seasonal"] = test_values["gamma"]
|
|
233
|
-
sm_test_values["damping_trend"] = test_values["phi"]
|
|
234
|
-
sm_test_values["initial_seasonal"] = test_values["initial_seasonal"][0]
|
|
235
|
-
for i in range(1, seasonal_periods):
|
|
236
|
-
sm_test_values[f"initial_seasonal.L{i}"] = test_values["initial_seasonal"][i]
|
|
237
|
-
|
|
238
|
-
vals = [
|
|
239
|
-
np.atleast_1d(test_values[name])
|
|
240
|
-
for name in ["initial_level", "initial_trend", "initial_seasonal"]
|
|
241
|
-
]
|
|
242
|
-
x0 = np.concatenate([[0.0], *vals])
|
|
243
|
-
|
|
244
|
-
mask = [True, True, order[1] != "N", *(order[2] != "N",) * seasonal_periods]
|
|
245
|
-
|
|
246
|
-
sm_mod.initialize_known(initial_state=x0[mask], initial_state_cov=np.eye(mod.k_states))
|
|
247
|
-
sm_mod.fit_constrained({name: sm_test_values[name] for name in sm_mod.param_names})
|
|
248
|
-
|
|
249
|
-
matrices = mod._unpack_statespace_with_placeholders()
|
|
250
|
-
inputs = list(explicit_graph_inputs(matrices))
|
|
251
|
-
input_names = [x.name for x in inputs]
|
|
252
|
-
|
|
253
|
-
f_matrices = pytensor.function(inputs, matrices)
|
|
254
|
-
test_values_subset = {name: test_values[name] for name in input_names}
|
|
255
|
-
|
|
256
|
-
matrices = f_matrices(**test_values_subset)
|
|
257
|
-
sm_matrices = [sm_mod.ssm[name] for name in LONG_MATRIX_NAMES[2:]]
|
|
258
|
-
|
|
259
|
-
for matrix, sm_matrix, name in zip(matrices[2:], sm_matrices, LONG_MATRIX_NAMES[2:]):
|
|
260
|
-
assert_allclose(matrix, sm_matrix, err_msg=f"{name} does not match")
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
@pytest.mark.parametrize("order, params", zip(orders, order_params), ids=order_names)
|
|
264
|
-
@pytest.mark.parametrize("dense_cov", [True, False], ids=["dense", "diagonal"])
|
|
265
|
-
def test_ETS_with_multiple_endog(rng, order, params, dense_cov):
|
|
266
|
-
seasonal_periods = 4
|
|
267
|
-
mod = BayesianETS(
|
|
268
|
-
order=order,
|
|
269
|
-
seasonal_periods=seasonal_periods,
|
|
270
|
-
measurement_error=False,
|
|
271
|
-
use_transformed_parameterization=True,
|
|
272
|
-
dense_innovation_covariance=dense_cov,
|
|
273
|
-
endog_names=["A", "B"],
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
single_mod = BayesianETS(
|
|
277
|
-
order=order,
|
|
278
|
-
seasonal_periods=seasonal_periods,
|
|
279
|
-
measurement_error=False,
|
|
280
|
-
use_transformed_parameterization=True,
|
|
281
|
-
)
|
|
282
|
-
|
|
283
|
-
simplex_params = ["alpha", "beta", "gamma"]
|
|
284
|
-
test_values = dict(zip(simplex_params, rng.dirichlet(alpha=np.ones(3), size=(mod.k_endog,)).T))
|
|
285
|
-
test_values["phi"] = rng.beta(1, 1, size=(mod.k_endog,))
|
|
286
|
-
|
|
287
|
-
test_values["initial_level"] = rng.normal(
|
|
288
|
-
size=mod.k_endog,
|
|
289
|
-
)
|
|
290
|
-
test_values["initial_trend"] = rng.normal(
|
|
291
|
-
size=mod.k_endog,
|
|
292
|
-
)
|
|
293
|
-
test_values["initial_seasonal"] = rng.normal(size=(mod.k_endog, seasonal_periods))
|
|
294
|
-
test_values["initial_state_cov"] = np.eye(mod.k_states)
|
|
295
|
-
|
|
296
|
-
if not dense_cov:
|
|
297
|
-
test_values["sigma_state"] = np.ones(
|
|
298
|
-
mod.k_endog,
|
|
299
|
-
)
|
|
300
|
-
else:
|
|
301
|
-
L = np.random.normal(size=(mod.k_endog, mod.k_endog))
|
|
302
|
-
test_values["state_cov"] = L @ L.T
|
|
303
|
-
|
|
304
|
-
# Compile functions for the joined model
|
|
305
|
-
matrices_pt = mod._unpack_statespace_with_placeholders()
|
|
306
|
-
inputs = list(explicit_graph_inputs(matrices_pt))
|
|
307
|
-
input_names = [x.name for x in inputs]
|
|
308
|
-
|
|
309
|
-
test_values_subset = {name: test_values[name] for name in input_names}
|
|
310
|
-
f_matrices = pytensor.function(inputs, matrices_pt)
|
|
311
|
-
|
|
312
|
-
matrices = f_matrices(**test_values_subset)
|
|
313
|
-
|
|
314
|
-
# Compile functions for the single model
|
|
315
|
-
single_matrices_pt = single_mod._unpack_statespace_with_placeholders()
|
|
316
|
-
single_inputs = list(explicit_graph_inputs(single_matrices_pt))
|
|
317
|
-
single_input_names = [x.name for x in single_inputs]
|
|
318
|
-
|
|
319
|
-
cursor = 0
|
|
320
|
-
single_test_values_subsets = []
|
|
321
|
-
for i in range(mod.k_endog):
|
|
322
|
-
single_slice = slice(cursor, cursor + single_mod.k_states)
|
|
323
|
-
d = {
|
|
324
|
-
name: (
|
|
325
|
-
test_values[name][i]
|
|
326
|
-
if name != "initial_state_cov"
|
|
327
|
-
else test_values_subset[name][single_slice, single_slice]
|
|
328
|
-
)
|
|
329
|
-
for name in single_input_names
|
|
330
|
-
if name != "sigma_state"
|
|
331
|
-
}
|
|
332
|
-
if dense_cov:
|
|
333
|
-
d["sigma_state"] = np.sqrt(test_values["state_cov"][i, i])
|
|
334
|
-
else:
|
|
335
|
-
d["sigma_state"] = test_values["sigma_state"][i]
|
|
336
|
-
single_test_values_subsets.append(d)
|
|
337
|
-
cursor += single_mod.k_states
|
|
338
|
-
|
|
339
|
-
f_single_matrices = pytensor.function(single_inputs, single_matrices_pt)
|
|
340
|
-
single_matrices = [f_single_matrices(**d) for d in single_test_values_subsets]
|
|
341
|
-
names = [x.name for x in matrices_pt]
|
|
342
|
-
|
|
343
|
-
for i, (x1, name) in enumerate(zip(matrices, names)):
|
|
344
|
-
cursor = 0
|
|
345
|
-
for j in range(mod.k_endog):
|
|
346
|
-
x2 = single_matrices[j][i]
|
|
347
|
-
state_slice = slice(cursor, cursor + single_mod.k_states)
|
|
348
|
-
obs_slice = slice(j, j + 1) # Also endog_slice -- it's doing double duty
|
|
349
|
-
if name in ["state_intercept", "initial_state"]:
|
|
350
|
-
assert_allclose(x1[state_slice], x2, err_msg=f"{name} does not match for case {j}")
|
|
351
|
-
elif name in ["P0", "initial_state_cov", "transition"]:
|
|
352
|
-
assert_allclose(
|
|
353
|
-
x1[state_slice, state_slice], x2, err_msg=f"{name} does not match for case {j}"
|
|
354
|
-
)
|
|
355
|
-
elif name == "selection":
|
|
356
|
-
assert_allclose(
|
|
357
|
-
x1[state_slice, obs_slice], x2, err_msg=f"{name} does not match for case {j}"
|
|
358
|
-
)
|
|
359
|
-
elif name == "design":
|
|
360
|
-
assert_allclose(
|
|
361
|
-
x1[obs_slice, state_slice], x2, err_msg=f"{name} does not match for case {j}"
|
|
362
|
-
)
|
|
363
|
-
elif name == "obs_intercept":
|
|
364
|
-
assert_allclose(x1[obs_slice], x2, err_msg=f"{name} does not match for case {j}")
|
|
365
|
-
elif name in ["obs_cov", "state_cov"]:
|
|
366
|
-
assert_allclose(
|
|
367
|
-
x1[obs_slice, obs_slice], x2, err_msg=f"{name} does not match for case {j}"
|
|
368
|
-
)
|
|
369
|
-
else:
|
|
370
|
-
raise ValueError(f"You forgot {name} !")
|
|
371
|
-
|
|
372
|
-
cursor += single_mod.k_states
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
def test_ETS_stationary_initialization():
|
|
376
|
-
mod = BayesianETS(
|
|
377
|
-
order=("A", "Ad", "A"),
|
|
378
|
-
seasonal_periods=4,
|
|
379
|
-
stationary_initialization=True,
|
|
380
|
-
initialization_dampening=0.66,
|
|
381
|
-
)
|
|
382
|
-
|
|
383
|
-
matrices = mod._unpack_statespace_with_placeholders()
|
|
384
|
-
inputs = list(explicit_graph_inputs(matrices))
|
|
385
|
-
input_names = [x.name for x in inputs]
|
|
386
|
-
|
|
387
|
-
# Make sure the stationary_dampening dummy variables was completely rewritten away
|
|
388
|
-
assert "stationary_dampening" not in input_names
|
|
389
|
-
|
|
390
|
-
# P0 should have been removed from param names
|
|
391
|
-
assert "P0" not in mod.param_names
|
|
392
|
-
assert "P0" not in mod.param_info.keys()
|
|
393
|
-
|
|
394
|
-
f = pytensor.function(inputs, matrices, mode="FAST_COMPILE")
|
|
395
|
-
test_values = f(**{x.name: np.full(x.type.shape, 0.5) for x in inputs})
|
|
396
|
-
outputs = {name: val for name, val in zip(LONG_MATRIX_NAMES, test_values)}
|
|
397
|
-
|
|
398
|
-
# Make sure that the transition matrix has ones in the expected positions, not the model dampening factor
|
|
399
|
-
assert outputs["transition"][1, 1] == 1.0
|
|
400
|
-
assert outputs["transition"][2, 2] == 0.5 # phi = 0.5 -- trend is dampened anyway
|
|
401
|
-
assert outputs["transition"][3, -1] == 1.0
|
|
402
|
-
|
|
403
|
-
# P0 should be equal to the solution to the Lyapunov equation using the dampening factors in the transition matrix
|
|
404
|
-
T_stationary = outputs["transition"].copy()
|
|
405
|
-
T_stationary[1, 1] = mod.initialization_dampening
|
|
406
|
-
T_stationary[3, -1] = mod.initialization_dampening
|
|
407
|
-
|
|
408
|
-
R, Q = outputs["selection"], outputs["state_cov"]
|
|
409
|
-
P0_expected = linalg.solve_discrete_lyapunov(T_stationary, R @ Q @ R.T)
|
|
410
|
-
|
|
411
|
-
assert_allclose(outputs["initial_state_cov"], P0_expected, rtol=1e-8, atol=1e-8)
|