pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/deserialize.py +224 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/prior.py +1356 -0
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.7.dist-info/METADATA +321 -0
- pymc_extras-0.2.7.dist-info/RECORD +66 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
- pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
tests/statespace/test_SARIMAX.py
DELETED
@@ -1,405 +0,0 @@
-from itertools import combinations
-
-import numpy as np
-import pymc as pm
-import pytensor
-import pytensor.tensor as pt
-import pytest
-import statsmodels.api as sm
-
-from numpy.testing import assert_allclose, assert_array_less
-
-from pymc_extras.statespace import BayesianSARIMA
-from pymc_extras.statespace.models.utilities import (
-    make_harvey_state_names,
-    make_SARIMA_transition_matrix,
-)
-from pymc_extras.statespace.utils.constants import (
-    SARIMAX_STATE_STRUCTURES,
-    SHORT_NAME_TO_LONG,
-)
-from tests.statespace.utilities.shared_fixtures import (  # pylint: disable=unused-import
-    rng,
-)
-from tests.statespace.utilities.test_helpers import (
-    load_nile_test_data,
-    make_stationary_params,
-    simulate_from_numpy_model,
-)
-
-floatX = pytensor.config.floatX
-ATOL = 1e-8 if floatX.endswith("64") else 1e-6
-RTOL = 0 if floatX.endswith("64") else 1e-6
-
-test_state_names = [
-    ["data", "state_1", "state_2"],
-    ["data", "data_star", "state_star_1", "state_star_2"],
-    ["data", "D1.data", "data_star", "state_star_1", "state_star_2"],
-    ["data", "D1.data", "D1^2.data", "data_star", "state_star_1", "state_star_2"],
-    ["data", "state_1", "state_2", "state_3"],
-    ["data", "state_1", "state_2", "state_3", "state_4", "state_5", "state_6", "state_7"],
-    [
-        "data",
-        "state_1",
-        "state_2",
-        "state_3",
-        "state_4",
-        "state_5",
-        "state_6",
-        "state_7",
-        "state_8",
-        "state_9",
-        "state_10",
-    ],
-    [
-        "data",
-        "L1.data",
-        "L2.data",
-        "L3.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-    ],
-    [
-        "data",
-        "L1.data",
-        "L2.data",
-        "L3.data",
-        "D4.data",
-        "L1D4.data",
-        "L2D4.data",
-        "L3D4.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-    ],
-    [
-        "data",
-        "D1.data",
-        "L1D1.data",
-        "L2D1.data",
-        "L3D1.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-        "state_star_4",
-        "state_star_5",
-    ],
-    [
-        "data",
-        "D1.data",
-        "D1^2.data",
-        "L1D1^2.data",
-        "L2D1^2.data",
-        "L3D1^2.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-        "state_star_4",
-        "state_star_5",
-    ],
-    [
-        "data",
-        "D1.data",
-        "D1^2.data",
-        "L1D1^2.data",
-        "L2D1^2.data",
-        "L3D1^2.data",
-        "D1^2D4.data",
-        "L1D1^2D4.data",
-        "L2D1^2D4.data",
-        "L3D1^2D4.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-        "state_star_4",
-        "state_star_5",
-    ],
-    [
-        "data",
-        "D1.data",
-        "L1D1.data",
-        "L2D1.data",
-        "D1D3.data",
-        "L1D1D3.data",
-        "L2D1D3.data",
-        "D1D3^2.data",
-        "L1D1D3^2.data",
-        "L2D1D3^2.data",
-        "data_star",
-        "state_star_1",
-        "state_star_2",
-        "state_star_3",
-        "state_star_4",
-    ],
-    ["data", "data_star"] + [f"state_star_{i+1}" for i in range(26)],
-]
-
-test_orders = [
-    (2, 0, 2, 0, 0, 0, 0),
-    (2, 1, 2, 0, 0, 0, 0),
-    (2, 2, 2, 0, 0, 0, 0),
-    (2, 3, 2, 0, 0, 0, 0),
-    (0, 0, 0, 1, 0, 0, 4),
-    (0, 0, 0, 2, 0, 1, 4),
-    (2, 0, 2, 2, 0, 2, 4),
-    (0, 0, 0, 1, 1, 0, 4),
-    (0, 0, 0, 1, 2, 0, 4),
-    (1, 1, 1, 1, 1, 1, 4),
-    (1, 2, 1, 1, 1, 1, 4),
-    (1, 2, 1, 1, 2, 1, 4),
-    (1, 1, 1, 1, 3, 1, 3),
-    (2, 1, 2, 2, 0, 2, 12),
-]
-
-ids = [f"p={p},d={d},q={q},P={P},D={D},Q={Q},S={S}" for (p, d, q, P, D, Q, S) in test_orders]
-
-
-@pytest.fixture
-def data():
-    return load_nile_test_data()
-
-
-@pytest.fixture(scope="session")
-def arima_mod():
-    return BayesianSARIMA(order=(2, 0, 1), stationary_initialization=True, verbose=False)
-
-
-@pytest.fixture(scope="session")
-def pymc_mod(arima_mod):
-    data = load_nile_test_data()
-
-    with pm.Model(coords=arima_mod.coords) as pymc_mod:
-        # x0 = pm.Normal('x0', dims=['state'])
-        # P0_diag = pm.Gamma('P0_diag', alpha=2, beta=1, dims=['state'])
-        # P0 = pm.Deterministic('P0', pt.diag(P0_diag), dims=['state', 'state_aux'])
-        ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"])
-        ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
-        sigma_state = pm.Exponential("sigma_state", 0.5)
-        arima_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
-
-    return pymc_mod
-
-
-@pytest.fixture(scope="session")
-def arima_mod_interp():
-    return BayesianSARIMA(
-        order=(3, 0, 3),
-        stationary_initialization=False,
-        verbose=False,
-        state_structure="interpretable",
-        measurement_error=True,
-    )
-
-
-@pytest.fixture(scope="session")
-def pymc_mod_interp(arima_mod_interp):
-    data = load_nile_test_data()
-
-    with pm.Model(coords=arima_mod_interp.coords) as pymc_mod:
-        x0 = pm.Normal("x0", dims=["state"])
-        P0_sigma = pm.Exponential("P0_sigma", 1)
-        P0 = pm.Deterministic(
-            "P0", pt.eye(arima_mod_interp.k_states) * P0_sigma, dims=["state", "state_aux"]
-        )
-        ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"])
-        ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
-        sigma_state = pm.Exponential("sigma_state", 0.5)
-        sigma_obs = pm.Exponential("sigma_obs", 0.1)
-
-        arima_mod_interp.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
-
-    return pymc_mod
-
-
-@pytest.mark.parametrize(
-    "p,d,q,P,D,Q,S,expected_names",
-    [(*order, name) for order, name in zip(test_orders, test_state_names)],
-    ids=ids,
-)
-def test_harvey_state_names(p, d, q, P, D, Q, S, expected_names):
-    if all([x == 0 for x in [p, d, q, P, D, Q, S]]):
-        pytest.skip("Skip all zero case")
-
-    k_states = max(p + P * S, q + Q * S + 1) + (S * D + d)
-    states = make_harvey_state_names(p, d, q, P, D, Q, S)
-
-    assert len(states) == k_states
-    missing_from_expected = set(expected_names) - set(states)
-    assert len(missing_from_expected) == 0
-
-
-@pytest.mark.parametrize("p,d,q,P,D,Q,S", test_orders)
-def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S):
-    T = make_SARIMA_transition_matrix(p, d, q, P, D, Q, S)
-    mod = sm.tsa.SARIMAX(np.random.normal(size=100), order=(p, d, q), seasonal_order=(P, D, Q, S))
-    T2 = mod.ssm["transition"]
-
-    if D > 2:
-        pytest.skip("Statsmodels has a bug when D > 2, skip this test.")
-    else:
-        assert_allclose(T, T2, err_msg="Transition matrix does not match statsmodels")
-
-
-@pytest.mark.parametrize("p, d, q, P, D, Q, S", test_orders, ids=ids)
-@pytest.mark.filterwarnings(
-    "ignore:Non-invertible starting MA parameters found.",
-    "ignore:Non-stationary starting autoregressive parameters found",
-    "ignore:Non-invertible starting seasonal moving average",
-    "ignore:Non-stationary starting seasonal autoregressive",
-)
-def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
-    sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
-
-    param_names = sm_sarimax.param_names
-    param_d = {name: getattr(np, floatX)(rng.normal(scale=0.1) ** 2) for name in param_names}
-
-    res = sm_sarimax.fit_constrained(param_d)
-    mod = BayesianSARIMA(
-        order=(p, d, q), seasonal_order=(P, D, Q, S), verbose=False, stationary_initialization=False
-    )
-
-    with pm.Model() as pm_mod:
-        x0 = pm.Normal("x0", shape=(mod.k_states,))
-        P0 = pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX))
-
-        if q > 0:
-            pm.Deterministic(
-                "ma_params",
-                pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ma.") and "S." not in k])
-                ),
-            )
-        if p > 0:
-            pm.Deterministic(
-                "ar_params",
-                pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ar.") and "S." not in k])
-                ),
-            )
-        if P > 0:
-            pm.Deterministic(
-                "seasonal_ar_params",
-                pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ar.S.")])
-                ),
-            )
-
-        if Q > 0:
-            pm.Deterministic(
-                "seasonal_ma_params",
-                pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ma.S.")])
-                ),
-            )
-
-        pm.Deterministic("sigma_state", pt.as_tensor_variable(np.sqrt(param_d["sigma2"])))
-
-        mod._insert_random_variables()
-        matrices = pm.draw(mod.subbed_ssm)
-        matrix_dict = dict(zip(SHORT_NAME_TO_LONG.values(), matrices))
-
-    for matrix in ["transition", "selection", "state_cov", "obs_cov", "design"]:
-        if matrix == "transition" and D > 2:
-            pytest.skip("Statsmodels has a bug when D > 2, skip this test.")
-        assert_allclose(matrix_dict[matrix], sm_sarimax.ssm[matrix], err_msg=f"{matrix} not equal")
-
-
-@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
-def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
-    rv = pymc_mod[f"{filter_output}_covariance"]
-    cov_mats = pm.draw(rv, 100, random_seed=rng)
-    w, v = np.linalg.eig(cov_mats)
-    assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")
-
-
-def test_interpretable_raises_if_d_nonzero():
-    with pytest.raises(
-        ValueError, match="Cannot use interpretable state structure with statespace differencing"
-    ):
-        BayesianSARIMA(
-            order=(2, 1, 1),
-            stationary_initialization=True,
-            verbose=False,
-            state_structure="interpretable",
-        )
-
-
-def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_interp):
-    with pymc_mod_interp:
-        prior = pm.sample_prior_predictive(draws=10)
-
-    prior_outputs = arima_mod_interp.sample_unconditional_prior(prior)
-    ar_lags = prior.prior.coords["ar_lag"].values - 1
-    ma_lags = prior.prior.coords["ma_lag"].values - 1
-
-    # Check the first p states are lags of the previous state
-    for t, tm1 in zip(ar_lags[1:], ar_lags[:-1]):
-        assert_allclose(
-            prior_outputs.prior_latent.isel(state=t).values[1:],
-            prior_outputs.prior_latent.isel(state=tm1).values[:-1],
-            err_msg=f"State {tm1} is not a lagged version of state {t} (AR lags)",
-        )
-
-    # Check the next p+q states are lags of the innovations
-    n = len(ar_lags)
-    for t, tm1 in zip(ma_lags[1:], ma_lags[:-1]):
-        assert_allclose(
-            prior_outputs.prior_latent.isel(state=n + t).values[1:],
-            prior_outputs.prior_latent.isel(state=n + tm1).values[:-1],
-            err_msg=f"State {n + tm1} is not a lagged version of state {n + t} (MA lags)",
-        )
-
-
-@pytest.mark.parametrize("p, d, q, P, D, Q, S", test_orders, ids=ids)
-@pytest.mark.filterwarnings(
-    "ignore:Non-invertible starting MA parameters found.",
-    "ignore:Non-stationary starting autoregressive parameters found",
-    "ignore:Maximum Likelihood optimization failed to converge.",
-)
-def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
-    if (d + D) > 0:
-        pytest.skip('state_structure = "interpretable" cannot include statespace differences')
-
-    shared_params = make_stationary_params(data, p, d, q, P, D, Q, S)
-    test_values = {}
-
-    for representation in SARIMAX_STATE_STRUCTURES:
-        rng = np.random.default_rng(sum(map(ord, "representation test")))
-        mod = BayesianSARIMA(
-            order=(p, d, q),
-            seasonal_order=(P, D, Q, S),
-            stationary_initialization=False,
-            verbose=False,
-            state_structure=representation,
-        )
-        shared_params.update(
-            {
-                "x0": np.zeros(mod.k_states, dtype=floatX),
-                "initial_state_cov": np.eye(mod.k_states, dtype=floatX) * 100,
-            }
-        )
-        x, y = simulate_from_numpy_model(mod, rng, shared_params)
-        test_values[representation] = y
-
-    all_pairs = combinations(SARIMAX_STATE_STRUCTURES, r=2)
-    for rep_1, rep_2 in all_pairs:
-        assert_allclose(
-            test_values[rep_1],
-            test_values[rep_2],
-            err_msg=f"{rep_1} and {rep_2} are not the same",
-            atol=ATOL,
-            rtol=RTOL,
-        )
-
-
-@pytest.mark.parametrize("order, name", [((4, 1, 0, 0), "AR"), ((0, 0, 4, 1), "MA")])
-def test_invalid_order_raises(order, name):
-    p, P, q, Q = order
-    with pytest.raises(ValueError, match=f"The following {name} and seasonal {name} terms overlap"):
-        BayesianSARIMA(order=(p, 0, q), seasonal_order=(P, 0, Q, 4))
tests/statespace/test_VARMAX.py
DELETED
@@ -1,184 +0,0 @@
-from itertools import pairwise, product
-
-import numpy as np
-import pandas as pd
-import pymc as pm
-import pytensor
-import pytensor.tensor as pt
-import pytest
-import statsmodels.api as sm
-
-from numpy.testing import assert_allclose, assert_array_less
-
-from pymc_extras.statespace import BayesianVARMAX
-from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
-from tests.statespace.utilities.shared_fixtures import (  # pylint: disable=unused-import
-    rng,
-)
-
-floatX = pytensor.config.floatX
-ps = [0, 1, 2, 3]
-qs = [0, 1, 2, 3]
-orders = list(product(ps, qs))[1:]
-ids = [f"p={x[0]}, q={x[1]}" for x in orders]
-
-
-@pytest.fixture(scope="session")
-def data():
-    df = pd.read_csv(
-        "tests/statespace/test_data/statsmodels_macrodata_processed.csv",
-        index_col=0,
-        parse_dates=True,
-    ).astype(floatX)
-    df.index.freq = df.index.inferred_freq
-    return df
-
-
-@pytest.fixture(scope="session")
-def varma_mod(data):
-    return BayesianVARMAX(
-        endog_names=data.columns,
-        order=(2, 0),
-        stationary_initialization=True,
-        verbose=False,
-        measurement_error=True,
-    )
-
-
-@pytest.fixture(scope="session")
-def pymc_mod(varma_mod, data):
-    with pm.Model(coords=varma_mod.coords) as pymc_mod:
-        # x0 = pm.Normal("x0", dims=["state"])
-        # P0_diag = pm.Exponential("P0_diag", 1, size=varma_mod.k_states)
-        # P0 = pm.Deterministic(
-        #     "P0", pt.diag(P0_diag), dims=["state", "state_aux"]
-        # )
-        state_chol, *_ = pm.LKJCholeskyCov(
-            "state_chol", n=varma_mod.k_posdef, eta=1, sd_dist=pm.Exponential.dist(1)
-        )
-        ar_params = pm.Normal(
-            "ar_params", mu=0, sigma=0.1, dims=["observed_state", "ar_lag", "observed_state_aux"]
-        )
-        state_cov = pm.Deterministic(
-            "state_cov", state_chol @ state_chol.T, dims=["shock", "shock_aux"]
-        )
-        sigma_obs = pm.Exponential("sigma_obs", 1, dims=["observed_state"])
-
-        varma_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
-
-    return pymc_mod
-
-
-@pytest.fixture(scope="session")
-def idata(pymc_mod, rng):
-    with pymc_mod:
-        idata = pm.sample_prior_predictive(draws=10, random_seed=rng)
-
-    return idata
-
-
-@pytest.mark.parametrize("order", orders, ids=ids)
-@pytest.mark.parametrize("var", ["AR", "MA", "state_cov"])
-@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.EstimationWarning")
-def test_VARMAX_param_counts_match_statsmodels(data, order, var):
-    p, q = order
-
-    mod = BayesianVARMAX(k_endog=data.shape[1], order=(p, q), verbose=False)
-    sm_var = sm.tsa.VARMAX(data, order=(p, q))
-
-    count = mod.param_counts[var]
-    if var == "state_cov":
-        # Statsmodels only counts the lower triangle
-        count = mod.k_posdef * (mod.k_posdef - 1)
-    assert count == sm_var.parameters[var.lower()]
-
-
-@pytest.mark.parametrize("order", orders, ids=ids)
-@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.EstimationWarning")
-@pytest.mark.filterwarnings("ignore::FutureWarning")
-def test_VARMAX_update_matches_statsmodels(data, order, rng):
-    p, q = order
-
-    sm_var = sm.tsa.VARMAX(data, order=(p, q))
-
-    param_counts = [None, *np.cumsum(list(sm_var.parameters.values())).tolist()]
-    param_slices = [slice(a, b) for a, b in pairwise(param_counts)]
-    param_lists = [trend, ar, ma, reg, state_cov, obs_cov] = [
-        sm_var.param_names[idx] for idx in param_slices
-    ]
-    param_d = {
-        k: getattr(np, floatX)(rng.normal(scale=0.1) ** 2)
-        for param_list in param_lists
-        for k in param_list
-    }
-
-    res = sm_var.fit_constrained(param_d)
-
-    mod = BayesianVARMAX(
-        k_endog=data.shape[1],
-        order=(p, q),
-        verbose=False,
-        measurement_error=False,
-        stationary_initialization=False,
-    )
-
-    ar_shape = (mod.k_endog, mod.p, mod.k_endog)
-    ma_shape = (mod.k_endog, mod.q, mod.k_endog)
-
-    with pm.Model() as pm_mod:
-        x0 = pm.Deterministic("x0", pt.zeros(mod.k_states, dtype=floatX))
-        P0 = pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX))
-        ma_params = pm.Deterministic(
-            "ma_params",
-            pt.as_tensor_variable(np.array([param_d[var] for var in ma])).reshape(ma_shape),
-        )
-        ar_params = pm.Deterministic(
-            "ar_params",
-            pt.as_tensor_variable(np.array([param_d[var] for var in ar])).reshape(ar_shape),
-        )
-        state_chol = np.zeros((mod.k_posdef, mod.k_posdef), dtype=floatX)
-        state_chol[np.tril_indices(mod.k_posdef)] = np.array([param_d[var] for var in state_cov])
-        state_cov = pm.Deterministic("state_cov", pt.as_tensor_variable(state_chol @ state_chol.T))
-        mod._insert_random_variables()
-
-        matrices = pm.draw(mod.subbed_ssm)
-        matrix_dict = dict(zip(SHORT_NAME_TO_LONG.values(), matrices))
-
-    for matrix in ["transition", "selection", "state_cov", "obs_cov", "design"]:
-        assert_allclose(matrix_dict[matrix], sm_var.ssm[matrix])
-
-
-@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
-def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
-    rv = pymc_mod[f"{filter_output}_covariance"]
-    cov_mats = pm.draw(rv, 100, random_seed=rng)
-    w, v = np.linalg.eig(cov_mats)
-    assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")
-
-
-parameters = [
-    {"n_steps": 10, "shock_size": None},
-    {"n_steps": 10, "shock_size": 1.0},
-    {"n_steps": 10, "shock_size": np.array([1.0, 0.0, 0.0])},
-    {
-        "n_steps": 10,
-        "shock_cov": np.array([[1.38, 0.58, -1.84], [0.58, 0.99, -0.82], [-1.84, -0.82, 2.51]]),
-    },
-    {
-        "shock_trajectory": np.r_[
-            np.zeros((3, 3), dtype=floatX),
-            np.array([[1.0, 0.0, 0.0]]).astype(floatX),
-            np.zeros((6, 3), dtype=floatX),
-        ]
-    },
-]
-
-ids = ["from-posterior-cov", "scalar_shock_size", "array_shock_size", "user-cov", "trajectory"]
-
-
-@pytest.mark.parametrize("parameters", parameters, ids=ids)
-@pytest.mark.skipif(floatX == "float32", reason="Impulse covariance not PSD if float32")
-def test_impulse_response(parameters, varma_mod, idata, rng):
-    irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters)
-
-    assert not np.any(np.isnan(irf.irf.values))