pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/deserialize.py +224 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/prior.py +1356 -0
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.7.dist-info/METADATA +321 -0
- pymc_extras-0.2.7.dist-info/RECORD +66 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
- pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,872 +0,0 @@
|
|
|
1
|
-
from collections.abc import Sequence
|
|
2
|
-
from functools import partial
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pandas as pd
|
|
6
|
-
import pymc as pm
|
|
7
|
-
import pytensor
|
|
8
|
-
import pytensor.tensor as pt
|
|
9
|
-
import pytest
|
|
10
|
-
|
|
11
|
-
from numpy.testing import assert_allclose
|
|
12
|
-
|
|
13
|
-
from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
|
|
14
|
-
from pymc_extras.statespace.models import structural as st
|
|
15
|
-
from pymc_extras.statespace.models.utilities import make_default_coords
|
|
16
|
-
from pymc_extras.statespace.utils.constants import (
|
|
17
|
-
FILTER_OUTPUT_NAMES,
|
|
18
|
-
MATRIX_NAMES,
|
|
19
|
-
SMOOTHER_OUTPUT_NAMES,
|
|
20
|
-
)
|
|
21
|
-
from tests.statespace.utilities.shared_fixtures import (
|
|
22
|
-
rng,
|
|
23
|
-
)
|
|
24
|
-
from tests.statespace.utilities.test_helpers import (
|
|
25
|
-
fast_eval,
|
|
26
|
-
load_nile_test_data,
|
|
27
|
-
make_test_inputs,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
# PyTensor's configured float dtype; test arrays below are cast to it.
floatX = pytensor.config.floatX

# Every named output checked in the sampled InferenceData groups: the statespace
# matrices followed by the Kalman filter and smoother outputs.
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES

# Nile test dataset shared by the fixtures and time-index helpers in this module.
nile = load_nile_test_data()
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
    """
    Build a bare concrete ``PyMCStateSpace`` for exercising base-class machinery.

    The returned model's ``make_symbolic_graph`` does nothing. Only the dimensions, the
    filter type and (optionally) exogenous-data metadata are set.

    Parameters
    ----------
    k_endog, k_states, k_posdef : int
        Model dimensions, forwarded unchanged to ``PyMCStateSpace``.
    filter_type : str
        Filter name forwarded to ``PyMCStateSpace``.
    verbose : bool, default False
        Forwarded to ``PyMCStateSpace``.
    data_info : dict, optional
        Value returned by the model's ``data_info`` property. When provided, the model is
        flagged as needing exogenous data and the dict's keys become ``_exog_names``.

    Returns
    -------
    PyMCStateSpace
        Instance of a locally defined subclass.
    """
    has_exog = data_info is not None

    class StateSpace(PyMCStateSpace):
        def make_symbolic_graph(self):
            # Intentionally empty: callers only need the model object itself.
            pass

        @property
        def data_info(self):
            return data_info

    model = StateSpace(
        k_endog=k_endog,
        k_states=k_states,
        k_posdef=k_posdef,
        filter_type=filter_type,
        verbose=verbose,
    )
    model._needs_exog_data = has_exog
    model._exog_names = list(data_info) if has_exog else []

    return model
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@pytest.fixture(scope="session")
def ss_mod():
    """
    Two-state model (states ``a``, ``b``) with scalar parameters ``rho`` and ``zeta``.

    ``rho`` and ``zeta`` fill the first column of the transition matrix. The design,
    selection, observation/state covariance and initial state covariance are fixed arrays.
    """

    class StateSpace(PyMCStateSpace):
        @property
        def param_names(self):
            return ["rho", "zeta"]

        @property
        def state_names(self):
            return ["a", "b"]

        @property
        def observed_states(self):
            return ["a"]

        @property
        def shock_names(self):
            return ["a"]

        @property
        def coords(self):
            return make_default_coords(self)

        def make_symbolic_graph(self):
            rho = self.make_and_register_variable("rho", ())
            zeta = self.make_and_register_variable("zeta", ())
            self.ssm["transition", 0, 0] = rho
            self.ssm["transition", 1, 0] = zeta

    model = StateSpace(
        k_endog=nile.shape[1], k_states=2, k_posdef=1, filter_type="standard", verbose=False
    )

    # Fixed (non-random) system matrices, assigned in this order.
    fixed_matrices = {
        "design": np.array([[1.0, 0.0]], dtype=floatX),
        "selection": np.array([[1.0], [0.0]], dtype=floatX),
        "obs_cov": np.array([[0.1]], dtype=floatX),
        "state_cov": np.array([[0.8]], dtype=floatX),
        "initial_state_cov": np.eye(2, dtype=floatX) * 1e6,
    }
    for matrix_name, value in fixed_matrices.items():
        model.ssm[matrix_name] = value

    return model
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
    """
    PyMC model wrapping ``ss_mod`` on the Nile data.

    Kalman filter outputs are saved in the InferenceData, and each unpacked statespace
    matrix is stored as a Deterministic named x0, P0, c, d, T, Z, R, H, Q.
    """
    matrix_labels = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]

    with pm.Model(coords=ss_mod.coords) as model:
        rho = pm.Beta("rho", 1, 1)
        pm.Deterministic("zeta", 1 - rho)

        ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True)
        for label, matrix in zip(matrix_labels, ss_mod.unpack_statespace()):
            pm.Deterministic(label, matrix)

    return model
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
@pytest.fixture(scope="session")
def ss_mod_no_exog(rng):
    """Built ``LevelTrendComponent(order=2, innovations_order=1)`` with no exogenous data.

    ``rng`` is requested but not used.
    """
    return st.LevelTrendComponent(order=2, innovations_order=1).build()
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@pytest.fixture(scope="session")
def ss_mod_no_exog_dt(rng):
    """Separate ``LevelTrendComponent(order=2, innovations_order=1)`` model for datetime-indexed data.

    A distinct instance is needed because a model's graph is built only once. ``rng`` is
    requested but not used.
    """
    return st.LevelTrendComponent(order=2, innovations_order=1).build()
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
@pytest.fixture(scope="session")
def exog_ss_mod(rng):
    """Default level/trend component plus an ``exog`` regression over states a, b, c.

    ``rng`` is requested but not used.
    """
    trend = st.LevelTrendComponent()
    regression = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
    return (trend + regression).build(verbose=False)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
    """PyMC model for ``exog_ss_mod`` on 100 random observations with 3 random regressors."""
    n_obs = 100
    # Draw order (observations, then regressors) matters for the shared rng stream.
    y = rng.normal(size=(n_obs, 1)).astype(floatX)
    X = rng.normal(size=(n_obs, 3)).astype(floatX)

    with pm.Model(coords=exog_ss_mod.coords) as model:
        pm.Data("data_exog", X)
        pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        pm.Normal("beta_exog", dims=["exog_state"])
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)

    return model
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
@pytest.fixture(scope="session")
def pymc_mod_no_exog(ss_mod_no_exog, rng):
    """PyMC model for ``ss_mod_no_exog`` on a 100-row single-column DataFrame of random data."""
    values = rng.normal(size=(100, 1)).astype(floatX)
    y = pd.DataFrame(values, columns=["y"])

    with pm.Model(coords=ss_mod_no_exog.coords) as model:
        pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        ss_mod_no_exog.build_statespace_graph(y)

    return model
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
@pytest.fixture(scope="session")
def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
    """PyMC model for ``ss_mod_no_exog_dt`` on 100 daily observations starting 2020-01-01."""
    daily_index = pd.date_range("2020-01-01", periods=100, freq="D")
    values = rng.normal(size=(100, 1)).astype(floatX)
    y = pd.DataFrame(values, columns=["y"], index=daily_index)

    with pm.Model(coords=ss_mod_no_exog_dt.coords) as model:
        pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        ss_mod_no_exog_dt.build_statespace_graph(y)

    return model
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
    """Ten posterior draws (one chain, no tuning) from ``pymc_mod`` plus ten prior draws."""
    with pymc_mod:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

    trace.extend(prior)
    return trace
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
    """Ten posterior draws (one chain, no tuning) from ``exog_pymc_mod`` plus ten prior draws."""
    with exog_pymc_mod:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
@pytest.fixture(scope="session")
def idata_no_exog(pymc_mod_no_exog, rng):
    """Ten posterior draws (one chain, no tuning) from ``pymc_mod_no_exog`` plus ten prior draws."""
    with pymc_mod_no_exog:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
@pytest.fixture(scope="session")
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
    """Ten posterior draws (one chain, no tuning) from ``pymc_mod_no_exog_dt`` plus ten prior draws."""
    with pymc_mod_no_exog_dt:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def test_invalid_filter_name_raises():
    """An unknown ``filter_type`` raises NotImplementedError listing the valid filter names."""
    import re

    msg = "The following are valid filter types: " + ", ".join(FILTER_FACTORY.keys())
    # ``match`` is interpreted as a regular expression; escape it so the filter names are
    # matched literally even if one contains a regex metacharacter.
    with pytest.raises(NotImplementedError, match=re.escape(msg)):
        make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter")
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
def test_unpack_before_insert_raises(rng):
    """``unpack_statespace`` raises until PyMC model variables have been inserted."""
    k_endog, k_states, k_posdef, n_steps = 2, 5, 1, 10
    # NOTE(review): the generated inputs are unused here; the call presumably mirrors
    # test_unpack_matrices (and consumes draws from the shared rng).
    _data, *_inputs = make_test_inputs(k_endog, k_states, k_posdef, n_steps, rng, missing_data=0)

    mod = make_statespace_mod(
        k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, filter_type="standard", verbose=False
    )

    expected = (
        "Cannot unpack the complete statespace system until PyMC model variables have been inserted."
    )
    with pytest.raises(ValueError, match=expected):
        mod.unpack_statespace()
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
def test_unpack_matrices(rng):
    """Unpacked matrices of a fresh dummy model are all zeros, shaped like the test inputs."""
    k_endog, k_states, k_posdef, n_steps = 2, 5, 1, 10
    _data, *reference_inputs = make_test_inputs(
        k_endog, k_states, k_posdef, n_steps, rng, missing_data=0
    )
    mod = make_statespace_mod(
        k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, filter_type="standard", verbose=False
    )

    # The dummy model has no placeholders, so the placeholder-form matrices can stand in
    # for the substituted system without a PyMC model.
    mod.subbed_ssm = mod._unpack_statespace_with_placeholders()

    for reference, unpacked in zip(reference_inputs, mod.unpack_statespace()):
        assert_allclose(np.zeros_like(reference), fast_eval(unpacked))
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
def test_param_names_raises_on_base_class():
    """A subclass that does not define ``param_names`` raises NotImplementedError on access."""
    mod = make_statespace_mod(
        k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False
    )
    with pytest.raises(NotImplementedError):
        _ = mod.param_names
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
def test_base_class_raises():
    """Instantiating ``PyMCStateSpace`` itself raises NotImplementedError."""
    init_kwargs = dict(k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False)
    with pytest.raises(NotImplementedError):
        PyMCStateSpace(**init_kwargs)
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
def test_update_raises_if_missing_variables(ss_mod):
    """Inserting variables fails and names the parameter missing from the PyMC model."""
    expected = "The following required model parameters were not found in the PyMC model: zeta"
    with pm.Model():
        # Only ``rho`` is defined; ``zeta`` is deliberately left out.
        pm.Normal("rho")
        with pytest.raises(ValueError, match=expected):
            ss_mod._insert_random_variables()
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
def test_build_statespace_graph_warns_if_data_has_nans():
    """Building the graph on all-NaN data emits an ImputationWarning."""
    # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over
    ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
    all_missing = np.full((10, 1), np.nan, dtype=floatX)

    with pm.Model():
        pm.Normal("initial_trend", shape=(1,))
        pm.Deterministic("P0", pt.eye(1, dtype=floatX))
        with pytest.warns(pm.ImputationWarning):
            ss_mod.build_statespace_graph(data=all_missing, register_data=False)
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def test_build_statespace_graph_raises_if_data_has_missing_fill():
    """A ``missing_fill_value`` that also occurs in the observed data is rejected."""
    # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over
    ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)

    # Mostly 1.0 with one NaN, so filling NaNs with 1.0 would be ambiguous.
    data = np.ones((10, 1), dtype=floatX)
    data[3] = np.nan

    with pm.Model():
        pm.Normal("initial_trend", shape=(1,))
        pm.Deterministic("P0", pt.eye(1, dtype=floatX))
        with pytest.raises(ValueError, match="Provided data contains the value 1.0"):
            ss_mod.build_statespace_graph(data=data, missing_fill_value=1.0, register_data=False)
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
def test_build_statespace_graph(pymc_mod):
    """Building the graph registers the Kalman filter outputs as Deterministics."""
    registered = {var.name for var in pymc_mod.deterministics}
    expected = [
        "filtered_state",
        "predicted_state",
        "predicted_covariance",
        "filtered_covariance",
    ]
    for name in expected:
        assert name in registered
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
def test_build_smoother_graph(ss_mod, pymc_mod):
    """Building the graph registers the smoother outputs as Deterministics."""
    registered = {var.name for var in pymc_mod.deterministics}
    for name in ("smoothed_state", "smoothed_covariance"):
        assert name in registered
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
def test_no_nans_in_sampling_output(group, matrix, idata):
    """No sampled matrix or filter/smoother output contains NaNs in either group."""
    values = idata[group][matrix].values
    assert not np.isnan(values).any()
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
def test_sampling_methods(group, kind, ss_mod, idata, rng):
    """``sample_{kind}_{group}`` produces the expected NaN-free groups."""
    sampler = getattr(ss_mod, f"sample_{kind}_{group}")
    result = sampler(idata, random_seed=rng)

    if kind == "conditional":
        for output in ("filtered", "predicted", "smoothed"):
            key = f"{output}_{group}"
            assert key in result
            assert not np.isnan(result[key].values).any()
            assert not np.isnan(result[f"{key}_observed"].values).any()
    elif kind == "unconditional":
        for output in ("latent", "observed"):
            key = f"{group}_{output}"
            assert key in result
            assert not np.isnan(result[key].values).any()
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
def test_sample_conditional_with_time_varying():
    """Prior sampling works when the state covariance varies over time."""

    class TVCovariance(PyMCStateSpace):
        # One state observed directly (transition = design = 1), with a per-time-step shock scale.
        def __init__(self):
            super().__init__(k_states=1, k_endog=1, k_posdef=1)

        @property
        def param_names(self) -> list[str]:
            return ["sigma_cov"]

        @property
        def coords(self) -> dict[str, Sequence[str]]:
            return make_default_coords(self)

        @property
        def state_names(self) -> list[str]:
            return ["level"]

        @property
        def observed_states(self) -> list[str]:
            return ["level"]

        @property
        def shock_names(self) -> list[str]:
            return ["level"]

        def make_symbolic_graph(self) -> None:
            self.ssm["transition", 0, 0] = 1.0
            self.ssm["design", 0, 0] = 1.0

            sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
            # Lift the 1-d vector to a stack of 1x1 covariance matrices, one per time step.
            self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2

    ss_mod = TVCovariance()
    time_index = pd.date_range("2020-01-01", periods=100, freq="D")
    empty_data = pd.DataFrame(np.nan, index=time_index, columns=["data"])

    model_coords = ss_mod.coords
    model_coords["time"] = empty_data.index
    with pm.Model(coords=model_coords):
        log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
        # The cumulative sum makes the log scale a random walk over time.
        pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])

        ss_mod.build_statespace_graph(data=empty_data)
        prior = pm.sample_prior_predictive(10)

    ss_mod.sample_unconditional_prior(prior)
    ss_mod.sample_conditional_prior(prior)
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def _make_time_idx(mod, use_datetime_index=True):
    """
    Record a "fit" time index on ``mod`` and return it.

    With ``use_datetime_index=True`` the Nile data's own index is used. Otherwise a 0-based
    integer ``RangeIndex`` of the same length is returned, and the reset Nile index is
    stored on the model.
    """
    if not use_datetime_index:
        mod._fit_coords["time"] = nile.reset_index().index
        return pd.RangeIndex(start=0, stop=nile.shape[0], step=1)

    mod._fit_coords["time"] = nile.index
    return nile.index
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_bad_forecast_arguments(use_datetime_index, caplog):
    """
    Forecast-argument validation rejects unfit models, out-of-range starts and missing horizons.

    When the scenario index is used, redundant start/periods arguments are logged if
    ``verbose`` is on and silently ignored if it is off.
    """
    ss_mod = make_statespace_mod(
        k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
    )

    # Not-fit model raises
    ss_mod._fit_coords = dict()
    with pytest.raises(ValueError, match="Has this model been fit?"):
        ss_mod._get_fit_time_index()

    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    # Start value not in time index
    match = (
        "Datetime start must be in the data index used to fit the model"
        if use_datetime_index
        else "Integer start must be within the range of the data index used to fit the model."
    )
    start = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11
    with pytest.raises(ValueError, match=match):
        ss_mod._validate_forecast_args(time_index=time_idx, start=start, periods=10)

    # End value cannot be inferred
    start = time_idx[-1]
    with pytest.raises(ValueError, match="Must specify one of either periods or end"):
        ss_mod._validate_forecast_args(time_index=time_idx, start=start)

    # Unnecessary args warn on verbose
    forecast_idx = pd.date_range(start=start, periods=10, freq="YS-JAN")
    scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])

    ss_mod._validate_forecast_args(
        time_index=time_idx, start=start, periods=10, scenario=scenario, use_scenario_index=True
    )
    assert "start, end, and periods arguments are ignored" in caplog.messages[-1]

    # verbose=False silences the warning. Compare against the current count rather than an
    # absolute total, so unrelated log records captured earlier do not break the check.
    n_messages_before = len(caplog.messages)
    ss_mod._validate_forecast_args(
        time_index=time_idx,
        start=start,
        periods=10,
        scenario=scenario,
        use_scenario_index=True,
        verbose=False,
    )
    assert len(caplog.messages) == n_messages_before
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_forecast_index(use_datetime_index):
    """
    ``_build_forecast_index`` returns ``(x0_index, forecast_index)`` for each way of
    specifying a forecast: start+end, start+periods, integer start, a scenario's index,
    or a dict of scenarios.
    """
    ss_mod = make_statespace_mod(
        k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
    )
    ss_mod._fit_coords = dict()
    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    # From start and end
    start = time_idx[-1]
    offset = pd.DateOffset(years=10) if use_datetime_index else 11
    end = start + offset

    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
    assert start not in forecast_idx
    assert x0_index == start
    assert forecast_idx.shape == (10,)

    # From start and periods
    start = time_idx[-1]
    periods = 10

    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
    assert start not in forecast_idx
    assert x0_index == start
    assert forecast_idx.shape == (10,)

    # From integer start
    start = 10
    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
    # Spacing between consecutive index entries, used below to recover x0 from the first
    # forecast period.
    step = forecast_idx.freq if use_datetime_index else 1

    assert x0_index == time_idx[start]
    assert forecast_idx.shape == (10,)
    assert (forecast_idx == time_idx[start + 1 : start + periods + 1]).all()

    # From scenario index. Assert on the x0 returned by *this* call; previously the result
    # was bound to an unused name and the stale x0 from the integer-start case was checked.
    scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
    x0_index, forecast_idx = ss_mod._build_forecast_index(
        time_index=time_idx, scenario=scenario, use_scenario_index=True
    )
    assert x0_index not in forecast_idx
    assert x0_index == (forecast_idx[0] - step)
    assert forecast_idx.shape == (10,)
    assert forecast_idx.equals(scenario.index)

    # From dictionary of scenarios
    scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
    x0_index, forecast_idx = ss_mod._build_forecast_index(
        time_index=time_idx, scenario=scenario, use_scenario_index=True
    )
    assert x0_index == (forecast_idx[0] - step)
    assert forecast_idx.shape == (10,)
    assert forecast_idx.equals(scenario["a"].index)
|
|
513
|
-
|
|
514
|
-
# From dictionary of scenarios
|
|
515
|
-
scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
|
|
516
|
-
x0_index, forecast_idx = ss_mod._build_forecast_index(
|
|
517
|
-
time_index=time_idx, scenario=scenario, use_scenario_index=True
|
|
518
|
-
)
|
|
519
|
-
assert x0_index == (forecast_idx[0] - delta)
|
|
520
|
-
assert forecast_idx.shape == (10,)
|
|
521
|
-
assert forecast_idx.equals(scenario["a"].index)
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
def test_validate_scenario(data_type):
    """``_validate_scenario_data`` accepts bare data for one exogenous variable and dicts otherwise."""
    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"])

    def build_model(info, fit_coords):
        # Fresh dummy model with the given exogenous metadata and recorded fit coords.
        model = make_statespace_mod(
            k_endog=1,
            k_posdef=1,
            k_states=2,
            filter_type="standard",
            verbose=False,
            data_info=info,
        )
        model._fit_coords = fit_coords
        return model

    # One data case
    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = build_model(data_info, dict(features_a=["column_1"]))

    scenario = ss_mod._validate_scenario_data(data_type(np.zeros(10)))

    # Lists and tuples are cast to 2d arrays
    if data_type in [tuple, list]:
        assert isinstance(scenario, np.ndarray)
        assert scenario.shape == (10, 1)

    # A one-item dictionary should also work
    ss_mod._validate_scenario_data({"a": scenario})

    # Now data has to be a dictionary
    data_info.update({"b": {"shape": (None, 1), "dims": ("time", "features_b")}})
    ss_mod = build_model(data_info, dict(features_a=["column_1"], features_b=["column_1"]))
    ss_mod._validate_scenario_data(
        {"a": data_type(np.zeros(10)), "b": data_type(np.zeros(10))}
    )

    # Mixed data types
    data_info.update({"a": {"shape": (None, 10), "dims": ("time", "features_a")}})
    ss_mod = build_model(
        data_info,
        dict(features_a=[f"column_{i}" for i in range(10)], features_b=["column_1"]),
    )
    mixed_scenario = {
        "a": pd.DataFrame(np.zeros((10, 10)), columns=ss_mod._fit_coords["features_a"]),
        "b": data_type(np.arange(10)),
    }
    ss_mod._validate_scenario_data(mixed_scenario)
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_finalize_scenario_single(data_type, use_datetime_index):
    """A single-variable scenario is finalized into a DataFrame on the forecast index."""
    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"])

    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"])

    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    scenario = data_type(np.zeros((10,)))

    scenario = ss_mod._validate_scenario_data(scenario)
    _, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
    scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)

    assert isinstance(scenario, pd.DataFrame)
    assert scenario.index.equals(forecast_idx)
    # Compare as a list: ``Index == list`` is element-wise and returns an array, which
    # raises on a length mismatch instead of failing the assertion.
    assert scenario.columns.tolist() == ["column_1"]
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
@pytest.mark.parametrize("use_datetime_index", [True, False])
@pytest.mark.parametrize("use_scenario_index", [True, False])
def test_finalize_secenario_dict(data_type, use_datetime_index, use_scenario_index):
    """A dict scenario is finalized into DataFrames on the forecast index, keys preserved."""
    data_info = {
        "a": {"shape": (None, 1), "dims": ("time", "features_a")},
        "b": {"shape": (None, 2), "dims": ("time", "features_b")},
    }
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1", "column_2"])
    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    if use_datetime_index:
        initial_index = pd.date_range(start=time_idx[-1], periods=10, freq=time_idx.freq)
    else:
        initial_index = pd.RangeIndex(time_idx[-1], time_idx[-1] + 10, 1)

    # Raw containers carry no index, so a start must be supplied when using the scenario index.
    is_unindexed = data_type in [np.array, list, tuple]

    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"], index=initial_index)
    elif data_type is pd.Series:
        data_type = partial(pd.Series, index=initial_index)

    raw_scenario = {
        "a": data_type(np.zeros((10,))),
        "b": pd.DataFrame(
            np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"], index=initial_index
        ),
    }
    scenario = ss_mod._validate_scenario_data(raw_scenario)

    if not use_scenario_index:
        _, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
    elif is_unindexed:
        _, forecast_idx = ss_mod._build_forecast_index(
            time_idx, scenario=scenario, start=-1, periods=10, use_scenario_index=True
        )
    else:
        _, forecast_idx = ss_mod._build_forecast_index(
            time_idx, scenario=scenario, periods=10, use_scenario_index=True
        )

    scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)

    assert list(scenario.keys()) == ["a", "b"]
    assert all(isinstance(frame, pd.DataFrame) for frame in scenario.values())
    assert all(frame.index.equals(forecast_idx) for frame in scenario.values())
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
def test_invalid_scenarios():
    """Check that ``_validate_scenario_data`` rejects malformed scenario inputs.

    Every call that is expected to raise gets its own ``pytest.raises`` context.
    Statements that follow the first raising call inside a ``pytest.raises``
    block are never executed, so grouping several calls in one block silently
    skips all but the first one.
    """
    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1", "column_2"])

    # Expected error when 1-column data is given for the 2-column variable 'a'.
    wrong_columns_msg = (
        "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1"
    )

    # Omitting the data raises
    with pytest.raises(
        ValueError, match="This model was fit using exogenous data. Forecasting cannot be performed"
    ):
        ss_mod._validate_scenario_data(None)

    # Giving a list, tuple, or Series when a matrix of data is expected should always raise,
    # both when passed bare and when passed inside a dict keyed by the variable name.
    # NOTE(review): previously only the bare-list call in this group ever ran; the tuple and
    # Series cases are now actually exercised.
    for data_type in [list, tuple, pd.Series]:
        with pytest.raises(ValueError, match=wrong_columns_msg):
            ss_mod._validate_scenario_data(data_type(np.zeros(10)))
        with pytest.raises(ValueError, match=wrong_columns_msg):
            ss_mod._validate_scenario_data({"a": data_type(np.zeros(10))})

    # Providing irrelevant data raises
    with pytest.raises(
        ValueError,
        match="Scenario data provided for variable 'jk lol', which is not an exogenous variable",
    ):
        ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})

    # Incorrect 2nd dimension of a non-dataframe (fresh objects for each call)
    bad_scenarios = [
        np.zeros(10).tolist(),
        tuple(np.zeros(10).tolist()),
        {"a": np.zeros(10).tolist()},
        {"a": tuple(np.zeros(10).tolist())},
    ]
    for bad_scenario in bad_scenarios:
        with pytest.raises(ValueError, match=wrong_columns_msg):
            ss_mod._validate_scenario_data(bad_scenario)

    # If a data frame is provided, it needs to have all columns
    scenario = pd.DataFrame(np.zeros((10, 1)), columns=["column_1"])
    with pytest.raises(
        ValueError, match="Scenario data for variable 'a' is missing the following column: column_2"
    ):
        ss_mod._validate_scenario_data(scenario)

    # Extra columns also raises
    scenario = pd.DataFrame(
        np.zeros((10, 4)), columns=["column_1", "column_2", "column_3", "column_4"]
    )
    with pytest.raises(
        ValueError,
        match="Scenario data for variable 'a' contains the following extra columns "
        "that are not used by the model: column_3, column_4",
    ):
        ss_mod._validate_scenario_data(scenario)

    # Wrong number of time steps raises
    data_info = {
        "a": {"shape": (None, 1), "dims": ("time", "features_a")},
        "b": {"shape": (None, 1), "dims": ("time", "features_b")},
    }
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(
        features_a=["column_1", "column_2"], features_b=["column_1", "column_2"]
    )

    # 'a' has 10 rows while 'b' has 11.
    scenario = {
        "a": pd.DataFrame(np.zeros((10, 2)), columns=ss_mod._fit_coords["features_a"]),
        "b": pd.DataFrame(np.zeros((11, 2)), columns=ss_mod._fit_coords["features_b"]),
    }
    with pytest.raises(
        ValueError, match="Scenario data must have the same number of time steps for all variables"
    ):
        ss_mod._validate_scenario_data(scenario)
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
@pytest.mark.parametrize(
    "mod_name, idata_name, start, end, periods",
    [
        ("ss_mod_no_exog", "idata_no_exog", None, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", -1, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", 10, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", 10, 21, None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", None, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", -1, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, "2020-01-21", None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", "2020-03-11", None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", None, 10),
    ],
    ids=[
        "range_default",
        "range_negative",
        "range_int",
        "range_end",
        "datetime_default",
        "datetime_negative",
        "datetime_int",
        "datetime_int_end",
        "datetime_datetime_end",
        "datetime_datetime",
    ],
)
def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng, request):
    """Forecast with several start/end/periods specs and check the resulting output.

    Asserts that 10 forecast steps are produced, that the latent and observed
    forecasts have the expected dims and contain no NaNs, and that the first
    forecast timestamp is one step after the resolved ``start``.
    """
    model = request.getfixturevalue(mod_name)
    fit_idata = request.getfixturevalue(idata_name)

    fit_index = model._get_fit_time_index()
    uses_dates = isinstance(fit_index, pd.DatetimeIndex)

    # Resolve the point the forecast is expected to continue from.
    if isinstance(start, str):
        anchor = pd.Timestamp(start)
    elif isinstance(start, int):
        anchor = fit_index[start]
    else:
        anchor = fit_index[-1]

    # One step: the index frequency for dates, a unit increment otherwise.
    if uses_dates:
        step = fit_index.freq
    else:
        step = 1

    forecast_idata = model.forecast(
        fit_idata,
        start=start,
        end=end,
        periods=periods,
        filter_output=filter_output,
        random_seed=rng,
    )

    time_values = forecast_idata.coords["time"].values
    if uses_dates:
        forecast_index = pd.DatetimeIndex(time_values)
    else:
        forecast_index = pd.Index(time_values)

    assert forecast_index.shape == (10,)

    latent = forecast_idata.forecast_latent
    observed = forecast_idata.forecast_observed
    assert latent.dims == ("chain", "draw", "time", "state")
    assert observed.dims == ("chain", "draw", "time", "observed_state")

    assert not np.isnan(latent.values).any()
    assert not np.isnan(observed.values).any()

    assert forecast_index[0] == (anchor + step)
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.parametrize("start", [None, -1, 10])
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
    """Check that the forecast regression effect matches sum(beta * scenario).

    The observed forecast minus the level component is compared against the
    exogenous betas multiplied by the scenario data and summed over states.
    """
    exog_states = ["exog[a]", "exog[b]", "exog[c]"]

    scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
    # Presumably a single large spike so that a misaligned regression effect is easy to spot.
    scenario.iloc[5, 0] = 1e9

    forecast_idata = exog_ss_mod.forecast(
        idata_exog,
        start=start,
        periods=10,
        random_seed=rng,
        scenario=scenario,
    )

    latent_components = exog_ss_mod.extract_components_from_idata(forecast_idata).forecast_latent
    level = latent_components.sel(state="LevelTrend[level]")
    betas = latent_components.sel(state=exog_states)

    # Reshape the scenario to (state, time) so it broadcasts against the betas.
    scenario.index.name = "time"
    stacked = scenario.unstack().to_xarray()
    scenario_xr = stacked.rename({"level_0": "state"}).assign_coords(state=exog_states)

    observed = forecast_idata.forecast_observed.isel(observed_state=0)
    regression_effect = observed - level
    regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])

    assert_allclose(regression_effect, regression_effect_expected)
|