pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,829 @@
|
|
|
1
|
+
import functools as ft
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import pymc as pm
|
|
10
|
+
import pytensor
|
|
11
|
+
import pytensor.tensor as pt
|
|
12
|
+
import pytest
|
|
13
|
+
import statsmodels.api as sm
|
|
14
|
+
|
|
15
|
+
from numpy.testing import assert_allclose
|
|
16
|
+
from scipy import linalg
|
|
17
|
+
|
|
18
|
+
from pymc_extras.statespace import structural as st
|
|
19
|
+
from pymc_extras.statespace.utils.constants import (
|
|
20
|
+
ALL_STATE_AUX_DIM,
|
|
21
|
+
ALL_STATE_DIM,
|
|
22
|
+
AR_PARAM_DIM,
|
|
23
|
+
OBS_STATE_AUX_DIM,
|
|
24
|
+
OBS_STATE_DIM,
|
|
25
|
+
SHOCK_AUX_DIM,
|
|
26
|
+
SHOCK_DIM,
|
|
27
|
+
SHORT_NAME_TO_LONG,
|
|
28
|
+
)
|
|
29
|
+
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
|
|
30
|
+
rng,
|
|
31
|
+
)
|
|
32
|
+
from tests.statespace.utilities.test_helpers import (
|
|
33
|
+
assert_pattern_repeats,
|
|
34
|
+
simulate_from_numpy_model,
|
|
35
|
+
unpack_symbolic_matrices_with_params,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
floatX = pytensor.config.floatX
|
|
39
|
+
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
|
|
40
|
+
RTOL = 0 if floatX.endswith("64") else 1e-6
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _assert_all_statespace_matrices_match(mod, params, sm_mod):
    """Check that the numeric state-space matrices of ``mod`` agree with statsmodels."""
    matrices = unpack_symbolic_matrices_with_params(mod, params)
    x0, P0, c, d, T, Z, R, H, Q = matrices

    sm_x0, sm_H0, sm_P0 = sm_mod.initialization()

    if len(x0) > 0:
        assert_allclose(x0, sm_x0)

    to_compare = {"T": T, "R": R, "Z": Z, "Q": Q}
    for short_name, matrix in to_compare.items():
        # Skip degenerate (size-zero) matrices
        if 0 in matrix.shape:
            continue
        assert_allclose(
            sm_mod.ssm[SHORT_NAME_TO_LONG[short_name]],
            matrix,
            err_msg=f"matrix {short_name} does not match statsmodels",
            atol=ATOL,
            rtol=RTOL,
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _assert_coord_shapes_match_matrices(mod, params):
    """Check that every system matrix's core shape matches the model coordinate lengths."""
    if "initial_state_cov" not in params:
        params["initial_state_cov"] = np.eye(mod.k_states)

    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, params)

    n_states = len(mod.coords[ALL_STATE_DIM])

    # There will always be one shock dimension -- dummies are inserted into fully
    # deterministic models to avoid errors in the state space representation.
    n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
    n_obs = len(mod.coords[OBS_STATE_DIM])

    # (name, matrix, number of core dims, expected core shape, shape description)
    checks = [
        ("x0", x0, 1, (n_states,), "(n_states, )"),
        ("P0", P0, 2, (n_states, n_states), "(n_states, n_states)"),
        ("c", c, 1, (n_states,), "(n_states, )"),
        ("d", d, 1, (n_obs,), "(n_obs, )"),
        ("T", T, 2, (n_states, n_states), "(n_states, n_states)"),
        ("Z", Z, 2, (n_obs, n_states), "(n_obs, n_states)"),
        ("R", R, 2, (n_states, n_shocks), "(n_states, n_shocks)"),
        ("H", H, 2, (n_obs, n_obs), "(n_obs, n_obs)"),
        ("Q", Q, 2, (n_shocks, n_shocks), "(n_shocks, n_shocks)"),
    ]
    for name, matrix, core_ndim, expected, desc in checks:
        found = matrix.shape[-core_ndim:]
        assert found == expected, f"{name} expected to have shape {desc}, found {found}"
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _assert_basic_coords_correct(mod):
    """Check the coordinate labels shared by every single-output structural model."""
    expected = {
        ALL_STATE_DIM: mod.state_names,
        ALL_STATE_AUX_DIM: mod.state_names,
        SHOCK_DIM: mod.shock_names,
        SHOCK_AUX_DIM: mod.shock_names,
        OBS_STATE_DIM: ["data"],
        OBS_STATE_AUX_DIM: ["data"],
    }
    for dim, labels in expected.items():
        assert mod.coords[dim] == labels
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _assert_keys_match(test_dict, expected_dict):
    """Assert that two dictionaries contain exactly the same set of keys."""
    expected_keys = set(expected_dict)
    found_keys = set(test_dict)

    missing = expected_keys - found_keys
    assert not missing, f'{", ".join(missing)} were not found in the test_dict keys.'

    extra = found_keys - expected_keys
    assert not extra, f'{", ".join(extra)} were keys of the tests_dict not in expected_dict.'
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _assert_param_dims_correct(param_dims, expected_dims):
    """Assert that inferred parameter dims equal the expected dims, key by key."""
    # Both empty: nothing to compare
    if not expected_dims and not param_dims:
        return

    _assert_keys_match(param_dims, expected_dims)
    for param_name, expected in expected_dims.items():
        assert expected == param_dims[param_name], f"dims for parameter {param_name} do not match"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _assert_coords_correct(coords, expected_coords):
    """Assert that coordinate labels match the expected labels, dimension by dimension."""
    # Both empty: nothing to compare
    if not coords and not expected_coords:
        return

    _assert_keys_match(coords, expected_coords)
    for dim_name, expected_labels in expected_coords.items():
        assert expected_labels == coords[dim_name], f"labels on dimension {dim_name} do not match"
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _assert_params_info_correct(param_info, coords, param_dims):
    """Check that each param_info entry is consistent with coords and param_dims."""
    for param, info in param_info.items():
        dims = info["dims"]

        if dims is None:
            # Scalar parameter: no labels, no dims to cross-check
            labels = None
            inferred_dims = None
        else:
            labels = [coords[dim] for dim in dims]
            assert param in param_dims.keys()
            inferred_dims = param_dims[param]

        expected_shape = tuple(len(label) for label in labels) if labels is not None else ()

        assert info["shape"] == expected_shape
        assert dims == inferred_dims
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def create_structural_model_and_equivalent_statsmodel(
    rng,
    level: bool | None = False,
    trend: bool | None = False,
    seasonal: int | None = None,
    freq_seasonal: list[dict] | None = None,
    cycle: bool = False,
    autoregressive: int | None = None,
    exog: np.ndarray | None = None,
    irregular: bool | None = False,
    stochastic_level: bool | None = True,
    stochastic_trend: bool | None = False,
    stochastic_seasonal: bool | None = True,
    stochastic_freq_seasonal: list[bool] | None = None,
    stochastic_cycle: bool | None = False,
    damped_cycle: bool | None = False,
):
    """Build a pymc-extras structural model and a matching statsmodels factory.

    Draws random parameter values for every requested component and returns
    everything a test needs to compare the two implementations:

    Returns
    -------
    tuple of (mod, st_mod, params, sm_params, sm_init, expected_param_dims,
    expected_coords), where ``mod`` is a ``functools.partial`` wrapping
    ``sm.tsa.UnobservedComponents`` (call it with data to instantiate),
    ``st_mod`` is the summed pymc-extras component model, ``params`` /
    ``sm_params`` / ``sm_init`` are parameter and initialization dictionaries
    keyed by each library's naming convention, and the ``expected_*`` dicts
    describe the coords/dims the built model should expose.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Delay instantiation: the caller supplies the data later.
        mod = ft.partial(
            sm.tsa.UnobservedComponents,
            level=level,
            trend=trend,
            seasonal=seasonal,
            freq_seasonal=freq_seasonal,
            cycle=cycle,
            autoregressive=autoregressive,
            exog=exog,
            irregular=irregular,
            stochastic_level=stochastic_level,
            stochastic_trend=stochastic_trend,
            stochastic_seasonal=stochastic_seasonal,
            stochastic_freq_seasonal=stochastic_freq_seasonal,
            stochastic_cycle=stochastic_cycle,
            damped_cycle=damped_cycle,
            mle_regression=False,
        )

    params = {}
    sm_params = {}
    sm_init = {}
    expected_param_dims = defaultdict(tuple)
    expected_coords = defaultdict(list)
    # P0 is always present regardless of components.
    expected_param_dims["P0"] += ("state", "state_aux")

    default_states = [
        ALL_STATE_DIM,
        ALL_STATE_AUX_DIM,
        OBS_STATE_DIM,
        OBS_STATE_AUX_DIM,
        SHOCK_DIM,
        SHOCK_AUX_DIM,
    ]
    default_values = [[], [], ["data"], ["data"], [], []]
    for dim, value in zip(default_states, default_values):
        expected_coords[dim] += value

    components = []

    if irregular:
        # pymc-extras parameterizes by sigma; statsmodels by sigma^2.
        sigma2 = np.abs(rng.normal()).astype(floatX).item()
        params["sigma_irregular"] = np.sqrt(sigma2)
        sm_params["sigma2.irregular"] = sigma2

        comp = st.MeasurementError("irregular")
        components.append(comp)

    level_trend_order = [0, 0]
    level_trend_innov_order = [0, 0]

    if level:
        level_trend_order[0] = 1
        expected_coords["trend_state"] += [
            "level",
        ]
        expected_coords[ALL_STATE_DIM] += [
            "level",
        ]
        expected_coords[ALL_STATE_AUX_DIM] += [
            "level",
        ]
        if stochastic_level:
            level_trend_innov_order[0] = 1
            expected_coords["trend_shock"] += ["level"]
            expected_coords[SHOCK_DIM] += [
                "level",
            ]
            expected_coords[SHOCK_AUX_DIM] += [
                "level",
            ]

    if trend:
        level_trend_order[1] = 1
        expected_coords["trend_state"] += [
            "trend",
        ]
        expected_coords[ALL_STATE_DIM] += [
            "trend",
        ]
        expected_coords[ALL_STATE_AUX_DIM] += [
            "trend",
        ]

        if stochastic_trend:
            level_trend_innov_order[1] = 1
            expected_coords["trend_shock"] += ["trend"]
            expected_coords[SHOCK_DIM] += ["trend"]
            expected_coords[SHOCK_AUX_DIM] += ["trend"]

    if level or trend:
        expected_param_dims["initial_trend"] += ("trend_state",)
        # Draw initial values only for the active components; inactive slots are zero.
        level_value = np.where(
            level_trend_order,
            rng.normal(
                size=2,
            ).astype(floatX),
            np.zeros(2, dtype=floatX),
        )
        # Keep variance draws only for slots with innovations.
        sigma_level_value2 = np.abs(rng.normal(size=(2,)))[
            np.array(level_trend_innov_order, dtype="bool")
        ]
        # Trim trailing inactive orders (highest nonzero entry determines the order).
        max_order = np.flatnonzero(level_value)[-1].item() + 1
        level_trend_order = level_trend_order[:max_order]

        params["initial_trend"] = level_value[:max_order]
        sm_init["level"] = level_value[0]
        sm_init["trend"] = level_value[1]

        if sum(level_trend_innov_order) > 0:
            expected_param_dims["sigma_trend"] += ("trend_shock",)
            params["sigma_trend"] = np.sqrt(sigma_level_value2)

        # statsmodels wants each sigma^2 individually, in level-then-trend order.
        sigma_level_value = sigma_level_value2.tolist()
        if stochastic_level:
            sigma = sigma_level_value.pop(0)
            sm_params["sigma2.level"] = sigma
        if stochastic_trend:
            sigma = sigma_level_value.pop(0)
            sm_params["sigma2.trend"] = sigma

        comp = st.LevelTrendComponent(
            name="level", order=level_trend_order, innovations_order=level_trend_innov_order
        )
        components.append(comp)

    if seasonal is not None:
        # First state is dropped (s-1 free coefficients).
        state_names = [f"seasonal_{i}" for i in range(seasonal)][1:]
        seasonal_coefs = rng.normal(size=(seasonal - 1,)).astype(floatX)
        params["seasonal_coefs"] = seasonal_coefs
        expected_param_dims["seasonal_coefs"] += ("seasonal_state",)

        expected_coords["seasonal_state"] += tuple(state_names)
        expected_coords[ALL_STATE_DIM] += state_names
        expected_coords[ALL_STATE_AUX_DIM] += state_names

        # statsmodels names the states "seasonal", "seasonal.L1", "seasonal.L2", ...
        seasonal_dict = {
            "seasonal" if i == 0 else f"seasonal.L{i}": c for i, c in enumerate(seasonal_coefs)
        }
        sm_init.update(seasonal_dict)

        if stochastic_seasonal:
            sigma2 = np.abs(rng.normal()).astype(floatX)
            params["sigma_seasonal"] = np.sqrt(sigma2)
            sm_params["sigma2.seasonal"] = sigma2
            expected_coords[SHOCK_DIM] += [
                "seasonal",
            ]
            expected_coords[SHOCK_AUX_DIM] += [
                "seasonal",
            ]

        comp = st.TimeSeasonality(
            name="seasonal", season_length=seasonal, innovations=stochastic_seasonal
        )
        components.append(comp)

    if freq_seasonal is not None:
        state_count = 0
        for d, has_innov in zip(freq_seasonal, stochastic_freq_seasonal):
            n = d["harmonics"]
            s = d["period"]
            # A fully saturated model (s == 2n) drops the last (unidentified) state.
            last_state_not_identified = (s / n) == 2.0
            n_states = 2 * n - int(last_state_not_identified)
            state_names = [f"seasonal_{s}_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"]]

            seasonal_params = rng.normal(size=n_states).astype(floatX)

            params[f"seasonal_{s}"] = seasonal_params
            expected_param_dims[f"seasonal_{s}"] += (f"seasonal_{s}_state",)
            expected_coords[ALL_STATE_DIM] += state_names
            expected_coords[ALL_STATE_AUX_DIM] += state_names
            expected_coords[f"seasonal_{s}_state"] += (
                tuple(state_names[:-1]) if last_state_not_identified else tuple(state_names)
            )

            for param in seasonal_params:
                sm_init[f"freq_seasonal.{state_count}"] = param
                state_count += 1
            if last_state_not_identified:
                # statsmodels still carries the dropped state; initialize it to zero.
                sm_init[f"freq_seasonal.{state_count}"] = 0.0
                state_count += 1

            if has_innov:
                sigma2 = np.abs(rng.normal()).astype(floatX)
                params[f"sigma_seasonal_{s}"] = np.sqrt(sigma2)
                sm_params[f"sigma2.freq_seasonal_{s}({n})"] = sigma2
                expected_coords[SHOCK_DIM] += state_names
                expected_coords[SHOCK_AUX_DIM] += state_names

            comp = st.FrequencySeasonality(
                name=f"seasonal_{s}", season_length=s, n=n, innovations=has_innov
            )
            components.append(comp)

    if cycle:
        cycle_length = np.random.choice(np.arange(2, 12)).astype(floatX)

        # Statsmodels takes the frequency not the cycle length, so convert it.
        sm_params["frequency.cycle"] = 2.0 * np.pi / cycle_length
        params["cycle_length"] = cycle_length

        init_cycle = rng.normal(size=(2,)).astype(floatX)
        params["cycle"] = init_cycle
        expected_param_dims["cycle"] += ("cycle_state",)

        state_names = ["cycle_Cos", "cycle_Sin"]
        expected_coords["cycle_state"] += state_names
        expected_coords[ALL_STATE_DIM] += state_names
        expected_coords[ALL_STATE_AUX_DIM] += state_names

        sm_init["cycle"] = init_cycle[0]
        sm_init["cycle.auxilliary"] = init_cycle[1]

        if stochastic_cycle:
            sigma2 = np.abs(rng.normal()).astype(floatX)
            params["sigma_cycle"] = np.sqrt(sigma2)
            expected_coords[SHOCK_DIM] += state_names
            expected_coords[SHOCK_AUX_DIM] += state_names

            sm_params["sigma2.cycle"] = sigma2

        if damped_cycle:
            rho = rng.beta(1, 1)
            params["cycle_dampening_factor"] = rho
            sm_params["damping.cycle"] = rho

        comp = st.CycleComponent(
            name="cycle",
            dampen=damped_cycle,
            innovations=stochastic_cycle,
            estimate_cycle_length=True,
        )

        components.append(comp)

    if autoregressive is not None:
        ar_names = [f"L{i+1}.data" for i in range(autoregressive)]
        ar_params = rng.normal(size=(autoregressive,)).astype(floatX)
        if autoregressive == 1:
            ar_params = ar_params.item()
        sigma2 = np.abs(rng.normal()).astype(floatX)

        params["ar_params"] = ar_params
        params["sigma_ar"] = np.sqrt(sigma2)
        expected_param_dims["ar_params"] += (AR_PARAM_DIM,)
        expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
        expected_coords[ALL_STATE_DIM] += ar_names
        expected_coords[ALL_STATE_AUX_DIM] += ar_names
        expected_coords[SHOCK_DIM] += ["ar_innovation"]
        expected_coords[SHOCK_AUX_DIM] += ["ar_innovation"]

        sm_params["sigma2.ar"] = sigma2
        for i, rho in enumerate(ar_params):
            sm_init[f"ar.L{i+1}"] = 0
            sm_params[f"ar.L{i+1}"] = rho

        comp = st.AutoregressiveComponent(name="ar", order=autoregressive)
        components.append(comp)

    if exog is not None:
        names = [f"x{i + 1}" for i in range(exog.shape[1])]
        betas = rng.normal(size=(exog.shape[1],)).astype(floatX)
        params["beta_exog"] = betas
        params["data_exog"] = exog
        expected_param_dims["beta_exog"] += ("exog_state",)
        expected_param_dims["data_exog"] += ("time", "exog_data")

        expected_coords["exog_state"] += tuple(names)

        for i, beta in enumerate(betas):
            sm_params[f"beta.x{i + 1}"] = beta
            sm_init[f"beta.x{i+1}"] = beta
        comp = st.RegressionComponent(name="exog", state_names=names)
        components.append(comp)

    # Sum all requested components into a single structural model.
    st_mod = components.pop(0)
    for comp in components:
        st_mod += comp
    return mod, st_mod, params, sm_params, sm_init, expected_param_dims, expected_coords
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
@pytest.mark.parametrize(
    "level, trend, stochastic_level, stochastic_trend, irregular",
    [
        (False, False, False, False, True),
        (True, True, True, True, True),
        (True, True, False, True, False),
    ],
)
@pytest.mark.parametrize("autoregressive", [None, 3])
@pytest.mark.parametrize("seasonal, stochastic_seasonal", [(None, False), (12, False), (12, True)])
@pytest.mark.parametrize(
    "freq_seasonal, stochastic_freq_seasonal",
    [
        (None, None),
        ([{"period": 12, "harmonics": 2}], [False]),
        ([{"period": 12, "harmonics": 6}], [True]),
    ],
)
@pytest.mark.parametrize(
    "cycle, damped_cycle, stochastic_cycle",
    [(False, False, False), (True, False, True), (True, True, True)],
)
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.ConvergenceWarning")
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.SpecificationWarning")
def test_structural_model_against_statsmodels(
    level,
    trend,
    stochastic_level,
    stochastic_trend,
    irregular,
    autoregressive,
    seasonal,
    stochastic_seasonal,
    freq_seasonal,
    stochastic_freq_seasonal,
    cycle,
    damped_cycle,
    stochastic_cycle,
    rng,
):
    """Cross-validate structural components against statsmodels UnobservedComponents.

    Builds a random combination of components in both libraries, pushes the
    same parameter values into each, and compares state-space matrices,
    coordinate shapes, parameter dims, and param_info metadata.
    """
    retvals = create_structural_model_and_equivalent_statsmodel(
        rng,
        level=level,
        trend=trend,
        seasonal=seasonal,
        freq_seasonal=freq_seasonal,
        cycle=cycle,
        damped_cycle=damped_cycle,
        autoregressive=autoregressive,
        irregular=irregular,
        stochastic_level=stochastic_level,
        stochastic_trend=stochastic_trend,
        stochastic_seasonal=stochastic_seasonal,
        stochastic_freq_seasonal=stochastic_freq_seasonal,
        stochastic_cycle=stochastic_cycle,
    )
    f_sm_mod, mod, params, sm_params, sm_init, expected_dims, expected_coords = retvals

    data = rng.normal(size=(100,)).astype(floatX)
    sm_mod = f_sm_mod(data)

    if len(sm_init) > 0:
        # Feed the drawn initial state into statsmodels, skipping its dummy state.
        init_array = np.concatenate(
            [np.atleast_1d(sm_init[k]).ravel() for k in sm_mod.state_names if k != "dummy"]
        )
        sm_mod.initialize_known(init_array, np.eye(sm_mod.k_states))
    else:
        sm_mod.initialize_default()

    if len(sm_params) > 0:
        # Parameters must be passed in statsmodels' own param_names order.
        param_array = np.concatenate(
            [np.atleast_1d(sm_params[k]).ravel() for k in sm_mod.param_names]
        )
        sm_mod.update(param_array, transformed=True)

    _assert_all_statespace_matrices_match(mod, params, sm_mod)

    built_model = mod.build(verbose=False)

    _assert_coord_shapes_match_matrices(built_model, params)
    _assert_param_dims_correct(built_model.param_dims, expected_dims)
    _assert_coords_correct(built_model.coords, expected_coords)
    _assert_params_info_correct(built_model.param_info, built_model.coords, built_model.param_dims)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def test_level_trend_model(rng):
    """A deterministic level + trend component produces a straight line with slope 1."""
    component = st.LevelTrendComponent(order=2, innovations_order=0)
    params = {"initial_trend": [0.0, 1.0]}
    x, y = simulate_from_numpy_model(component, rng, params)

    # With unit slope and no innovations, first differences are identically 1
    assert_allclose(np.diff(y), 1, atol=ATOL, rtol=RTOL)

    # Check coords
    built = component.build(verbose=False)
    _assert_basic_coords_correct(built)
    assert built.coords["trend_state"] == ["level", "trend"]
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def test_measurement_error(rng):
    """Adding a MeasurementError component exposes a sigma_obs parameter."""
    combined = st.MeasurementError("obs") + st.LevelTrendComponent(order=2)
    built = combined.build(verbose=False)

    _assert_basic_coords_correct(built)
    assert "sigma_obs" in built.param_names
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
@pytest.mark.parametrize("order", [1, 2, [1, 0, 1]], ids=["AR1", "AR2", "AR(1,0,1)"])
def test_autoregressive_model(order, rng):
    """An AR component simulates and exposes lag coords, including gapped lag masks."""
    ar = st.AutoregressiveComponent(order=order)
    params = {
        "ar_params": np.full((sum(ar.order),), 0.5, dtype=floatX),
        "sigma_ar": 0.0,
    }

    x, y = simulate_from_numpy_model(ar, rng, params, steps=100)

    # Check coords
    ar.build(verbose=False)
    _assert_basic_coords_correct(ar)

    # A list order is a 0/1 mask over lags; an int order means lags 1..order
    if isinstance(order, list):
        expected_lags = np.flatnonzero(order) + 1
    else:
        expected_lags = np.arange(order, dtype="int") + 1
    assert_allclose(ar.coords["ar_lag"], expected_lags)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
def test_time_seasonality(s, innovations, rng):
    """Time seasonality repeats every s steps and names its states from the input."""

    def make_name(rng):
        letters = list("abcdefghijklmnopqrstuvwxyz")
        return "".join(rng.choice(letters) for _ in range(5))

    state_names = [make_name(rng) for _ in range(s)]
    season = st.TimeSeasonality(
        season_length=s, innovations=innovations, name="season", state_names=state_names
    )

    # Unit impulse in the first seasonal state
    initial_coefs = np.zeros(season.k_states, dtype=floatX)
    initial_coefs[0] = 1

    params = {"season_coefs": initial_coefs}
    if season.innovations:
        # Zero variance keeps the simulation deterministic
        params["sigma_season"] = 0.0

    x, y = simulate_from_numpy_model(season, rng, params)
    y = y.ravel()
    if not innovations:
        assert_pattern_repeats(y, s, atol=ATOL, rtol=RTOL)

    # Check coords: only s-1 states are kept, so the first name is dropped
    season.build(verbose=False)
    _assert_basic_coords_correct(season)
    assert season.coords["season_state"] == state_names[1:]
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def get_shift_factor(s):
    """Return the smallest power of 10 that turns ``s`` into an integer.

    Parameters
    ----------
    s : int or float
        A (possibly fractional) season length.

    Returns
    -------
    int
        10 ** (number of significant decimal digits); 1 for integers.
    """
    s_str = str(s)
    if "." not in s_str:
        return 1
    _, decimal = s_str.split(".")
    # Trailing zeros carry no information: 25.0 needs no shift at all.
    decimal = decimal.rstrip("0")
    return 10 ** len(decimal)
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
@pytest.mark.parametrize("n", [*np.arange(1, 6, dtype="int").tolist(), None])
@pytest.mark.parametrize("s", [5, 10, 25, 25.2])
def test_frequency_seasonality(n, s, rng):
    """Frequency seasonality repeats with period s, including fractional lengths."""
    season = st.FrequencySeasonality(season_length=s, n=n, name="season")
    params = {
        "season": rng.normal(size=season.n_coefs).astype(floatX),
        "sigma_season": 0.0,
    }

    # Fractional season lengths only repeat after scaling up to an integer period
    period = int(s * get_shift_factor(s))
    x, y = simulate_from_numpy_model(season, rng, params, steps=2 * period)
    assert_pattern_repeats(y, period, atol=ATOL, rtol=RTOL)

    # Check coords
    season.build(verbose=False)
    _assert_basic_coords_correct(season)

    n_harmonics = int(s // 2) if n is None else n
    expected_states = [f"season_{f}_{i}" for i in range(n_harmonics) for f in ["Cos", "Sin"]]

    # Remove the last state when the model is completely saturated
    if s / n_harmonics == 2.0:
        expected_states.pop()
    assert season.coords["season_state"] == expected_states
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
# NOTE(review): not referenced by any test visible in this chunk -- possibly leftover;
# verify against the rest of the file before removing.
cycle_test_vals = zip([None, None, 3, 5, 10], [False, True, True, False, False])
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def test_cycle_component_deterministic(rng):
    """A noiseless, undamped cycle of length 12 repeats exactly every 12 steps."""
    component = st.CycleComponent(
        name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False
    )
    initial_state = {"cycle": np.array([1.0, 1.0], dtype=floatX)}
    x, y = simulate_from_numpy_model(component, rng, initial_state, steps=12 * 12)

    assert_pattern_repeats(y, 12, atol=ATOL, rtol=RTOL)
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def test_cycle_component_with_dampening(rng):
    """With a dampening factor below one, the cycle amplitude decays toward zero."""
    component = st.CycleComponent(
        name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False, dampen=True
    )
    sim_params = {
        "cycle": np.array([10.0, 10.0], dtype=floatX),
        "cycle_dampening_factor": 0.75,
    }
    _, trajectory = simulate_from_numpy_model(component, rng, sim_params, steps=100)

    # After 100 dampened steps the signal should have decayed to ~0
    assert_allclose(trajectory[-1], 0.0, atol=ATOL, rtol=RTOL)
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
def test_cycle_component_with_innovations_and_cycle_length(rng):
    """Smoke-test a stochastic cycle with estimated length and dampening."""
    component = st.CycleComponent(
        name="cycle", estimate_cycle_length=True, innovations=True, dampen=True
    )
    sim_params = {
        "cycle": np.array([1.0, 1.0], dtype=floatX),
        "cycle_length": 12.0,
        "cycle_dampening_factor": 0.95,
        "sigma_cycle": 1.0,
    }

    # The simulation only needs to run without error; its outputs are not checked
    simulate_from_numpy_model(component, rng, sim_params)

    component.build(verbose=False)
    _assert_basic_coords_correct(component)
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def test_exogenous_component(rng):
    """A regression component should reproduce X @ beta exactly."""
    features = rng.normal(size=(100, 2)).astype(floatX)
    exog = st.RegressionComponent(state_names=["feature_1", "feature_2"], name="exog")

    betas = {"beta_exog": np.array([1.0, 2.0], dtype=floatX)}
    data_dict = {"data_exog": features}
    _, obs = simulate_from_numpy_model(exog, rng, betas, data_dict)

    # The simulated observations are a deterministic linear regression
    assert_allclose(obs, features @ betas["beta_exog"], atol=ATOL, rtol=RTOL)

    exog.build(verbose=False)
    _assert_basic_coords_correct(exog)
    assert exog.coords["exog_state"] == ["feature_1", "feature_2"]
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def test_adding_exogenous_component(rng):
    """Summed components should stack their states into one design matrix."""
    features = rng.normal(size=(100, 2)).astype(floatX)
    exog = st.RegressionComponent(state_names=["a", "b"], name="exog")
    trend = st.LevelTrendComponent(name="level")

    annual = st.FrequencySeasonality(name="annual", season_length=12, n=4)
    combined = exog + trend + annual

    # 2 regression states + 2 trend states + 8 seasonal states
    assert combined.ssm["design"].eval({"data_exog": features}).shape == (100, 1, 2 + 2 + 8)
    # The first two design columns carry the exogenous data, row for row
    assert_allclose(combined.ssm["design", 5, 0, :2].eval({"data_exog": features}), features[5])
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def test_add_components():
    """Adding two components combines their statespace matrices correctly.

    T, R, and Q must combine block-diagonally; x0, c, and Z must concatenate
    along the state (for Z: column) axis.  Metadata (parameter names, shocks,
    coords, dims) from each part must survive the merge.
    """
    ll = st.LevelTrendComponent(order=2)
    se = st.TimeSeasonality(name="seasonal", season_length=12)
    mod = ll + se

    ll_params = {
        "initial_trend": np.zeros(2, dtype=floatX),
        "sigma_trend": np.ones(2, dtype=floatX),
    }
    se_params = {
        "seasonal_coefs": np.ones(11, dtype=floatX),
        "sigma_seasonal": 1.0,
    }
    all_params = {**ll_params, **se_params}

    (ll_x0, ll_P0, ll_c, ll_d, ll_T, ll_Z, ll_R, ll_H, ll_Q) = unpack_symbolic_matrices_with_params(
        ll, ll_params
    )
    (se_x0, se_P0, se_c, se_d, se_T, se_Z, se_R, se_H, se_Q) = unpack_symbolic_matrices_with_params(
        se, se_params
    )
    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, all_params)

    # BUG FIX: the original asserted on a non-empty list comprehension
    # (`assert [x in ... for x in ...]`), which is always truthy and so could
    # never fail.  Use all() so missing metadata actually fails the test.
    # Also renamed the loop variable, which shadowed the builtin `property`.
    for attr in ["param_names", "shock_names", "param_info", "coords", "param_dims"]:
        assert all(x in getattr(mod, attr) for x in getattr(ll, attr))
        assert all(x in getattr(mod, attr) for x in getattr(se, attr))

    # Transition, selection, and shock-covariance matrices combine block-diagonally
    ll_mats = [ll_T, ll_R, ll_Q]
    se_mats = [se_T, se_R, se_Q]
    all_mats = [T, R, Q]

    for ll_mat, se_mat, all_mat in zip(ll_mats, se_mats, all_mats):
        assert_allclose(all_mat, linalg.block_diag(ll_mat, se_mat), atol=ATOL, rtol=RTOL)

    # Initial state and intercept concatenate on axis 0; design matrix on axis 1
    ll_mats = [ll_x0, ll_c, ll_Z]
    se_mats = [se_x0, se_c, se_Z]
    all_mats = [x0, c, Z]
    axes = [0, 0, 1]

    for ll_mat, se_mat, all_mat, axis in zip(ll_mats, se_mats, all_mats, axes):
        assert_allclose(all_mat, np.concatenate([ll_mat, se_mat], axis=axis), atol=ATOL, rtol=RTOL)
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def test_filter_scans_time_varying_design_matrix(rng):
    """The regression component's design matrix must vary over time with the data."""
    dates = pd.date_range(start="2000-01-01", freq="D", periods=100)
    exog_df = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=dates)

    obs_df = pd.DataFrame(rng.normal(size=(100, 1)), columns=["data"], index=dates)

    mod = st.RegressionComponent(state_names=["a", "b"], name="exog").build(verbose=False)

    with pm.Model(coords=mod.coords):
        pm.Data("data_exog", exog_df.values)

        pm.Normal("x0", dims=["state"])
        pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
        pm.Normal("beta_exog", dims=["exog_state"])

        mod.build_statespace_graph(obs_df)
        _, _, _, _, _, Z, _, _, _ = mod.unpack_statespace()
        pm.Deterministic("Z", Z)

        prior = pm.sample_prior_predictive(draws=10)

    Z_draws = prior.prior.Z.values
    # Shape is (chain, draw, time, obs, state): one design row per time step per draw
    assert Z_draws.shape == (1, 10, 100, 1, 2)
    assert_allclose(Z_draws[0, :, :, 0, :], exog_df.values[None].repeat(10, axis=0))
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
@pytest.mark.skipif(floatX.endswith("32"), reason="Prior covariance not PSD at half-precision")
def test_extract_components_from_idata(rng):
    # End-to-end check that extract_components_from_idata splits filtered state
    # trajectories back into per-component states with human-readable names.
    time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)
    data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx)

    y = pd.DataFrame(rng.normal(size=(100, 1)), columns=["data"], index=time_idx)

    # Build a model from four component types: trend, seasonality, regression,
    # and measurement error.
    ll = st.LevelTrendComponent()
    season = st.FrequencySeasonality(name="seasonal", season_length=12, n=2, innovations=False)
    reg = st.RegressionComponent(state_names=["a", "b"], name="exog")
    me = st.MeasurementError("obs")
    mod = (ll + season + reg + me).build(verbose=False)

    # Priors for every statespace parameter, matched to the model's coords/dims
    with pm.Model(coords=mod.coords) as m:
        data_exog = pm.Data("data_exog", data.values)

        x0 = pm.Normal("x0", dims=["state"])
        P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
        beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
        seasonal_coefs = pm.Normal("seasonal", dims=["seasonal_state"])
        sigma_obs = pm.Exponential("sigma_obs", 1)

        mod.build_statespace_graph(y)

        x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
        prior = pm.sample_prior_predictive(draws=10)

    # Run the Kalman filter on prior draws, then split states by component
    filter_prior = mod.sample_conditional_prior(prior)
    comp_prior = mod.extract_components_from_idata(filter_prior)
    comp_states = comp_prior.filtered_prior.coords["state"].values
    # NOTE(review): expected naming convention "Component[state]" for multi-state
    # components, bare name otherwise — confirm against extract_components_from_idata.
    expected_states = ["LevelTrend[level]", "LevelTrend[trend]", "seasonal", "exog[a]", "exog[b]"]
    # Only checks comp_states is a subset of expected_states (extras in
    # expected_states would not be caught by this direction of the set difference).
    missing = set(comp_states) - set(expected_states)

    assert len(missing) == 0, missing
|