pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as published.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
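The largest additions in 0.2.7 are the new pymc_extras/prior.py (+1356 lines) and pymc_extras/deserialize.py (+224 lines) modules. The sketch below is a rough orientation only, assuming the Prior API matches the class this module appears to port; none of these names or signatures are confirmed by this diff, so check the released source before relying on them.

    import pymc as pm

    from pymc_extras.prior import Prior  # new module in 0.2.7; API assumed, not shown in this diff

    # Hierarchical prior declared outside any model context (assumed API).
    sigma = Prior("HalfNormal", sigma=1)
    mu = Prior("Normal", mu=0, sigma=sigma, dims="group")

    with pm.Model(coords={"group": ["a", "b", "c"]}):
        # create_variable is assumed to materialize the declaration as PyMC random variables.
        group_effect = mu.create_variable("group_effect")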
tests/statespace/test_statespace.py
@@ -1,872 +0,0 @@
- from collections.abc import Sequence
- from functools import partial
-
- import numpy as np
- import pandas as pd
- import pymc as pm
- import pytensor
- import pytensor.tensor as pt
- import pytest
-
- from numpy.testing import assert_allclose
-
- from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
- from pymc_extras.statespace.models import structural as st
- from pymc_extras.statespace.models.utilities import make_default_coords
- from pymc_extras.statespace.utils.constants import (
-     FILTER_OUTPUT_NAMES,
-     MATRIX_NAMES,
-     SMOOTHER_OUTPUT_NAMES,
- )
- from tests.statespace.utilities.shared_fixtures import (
-     rng,
- )
- from tests.statespace.utilities.test_helpers import (
-     fast_eval,
-     load_nile_test_data,
-     make_test_inputs,
- )
-
- floatX = pytensor.config.floatX
- nile = load_nile_test_data()
- ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
-
-
- def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
-     class StateSpace(PyMCStateSpace):
-         def make_symbolic_graph(self):
-             pass
-
-         @property
-         def data_info(self):
-             return data_info
-
-     ss = StateSpace(
-         k_states=k_states,
-         k_endog=k_endog,
-         k_posdef=k_posdef,
-         filter_type=filter_type,
-         verbose=verbose,
-     )
-     ss._needs_exog_data = data_info is not None
-     ss._exog_names = list(data_info.keys()) if data_info is not None else []
-
-     return ss
-
-
- @pytest.fixture(scope="session")
- def ss_mod():
-     class StateSpace(PyMCStateSpace):
-         @property
-         def param_names(self):
-             return ["rho", "zeta"]
-
-         @property
-         def state_names(self):
-             return ["a", "b"]
-
-         @property
-         def observed_states(self):
-             return ["a"]
-
-         @property
-         def shock_names(self):
-             return ["a"]
-
-         @property
-         def coords(self):
-             return make_default_coords(self)
-
-         def make_symbolic_graph(self):
-             rho = self.make_and_register_variable("rho", ())
-             zeta = self.make_and_register_variable("zeta", ())
-             self.ssm["transition", 0, 0] = rho
-             self.ssm["transition", 1, 0] = zeta
-
-     Z = np.array([[1.0, 0.0]], dtype=floatX)
-     R = np.array([[1.0], [0.0]], dtype=floatX)
-     H = np.array([[0.1]], dtype=floatX)
-     Q = np.array([[0.8]], dtype=floatX)
-     P0 = np.eye(2, dtype=floatX) * 1e6
-
-     ss_mod = StateSpace(
-         k_endog=nile.shape[1], k_states=2, k_posdef=1, filter_type="standard", verbose=False
-     )
-     for X, name in zip(
-         [Z, R, H, Q, P0],
-         ["design", "selection", "obs_cov", "state_cov", "initial_state_cov"],
-     ):
-         ss_mod.ssm[name] = X
-
-     return ss_mod
-
-
- @pytest.fixture(scope="session")
- def pymc_mod(ss_mod):
-     with pm.Model(coords=ss_mod.coords) as pymc_mod:
-         rho = pm.Beta("rho", 1, 1)
-         zeta = pm.Deterministic("zeta", 1 - rho)
-
-         ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True)
-         names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
-         for name, matrix in zip(names, ss_mod.unpack_statespace()):
-             pm.Deterministic(name, matrix)
-
-     return pymc_mod
-
-
- @pytest.fixture(scope="session")
- def ss_mod_no_exog(rng):
-     ll = st.LevelTrendComponent(order=2, innovations_order=1)
-     return ll.build()
-
-
- @pytest.fixture(scope="session")
- def ss_mod_no_exog_dt(rng):
-     ll = st.LevelTrendComponent(order=2, innovations_order=1)
-     return ll.build()
-
-
- @pytest.fixture(scope="session")
- def exog_ss_mod(rng):
-     ll = st.LevelTrendComponent()
-     reg = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
-     mod = (ll + reg).build(verbose=False)
-
-     return mod
-
-
- @pytest.fixture(scope="session")
- def exog_pymc_mod(exog_ss_mod, rng):
-     y = rng.normal(size=(100, 1)).astype(floatX)
-     X = rng.normal(size=(100, 3)).astype(floatX)
-
-     with pm.Model(coords=exog_ss_mod.coords) as m:
-         exog_data = pm.Data("data_exog", X)
-         initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
-         P0_sigma = pm.Exponential("P0_sigma", 1)
-         P0 = pm.Deterministic(
-             "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
-         )
-         beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
-
-         sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
-         exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)
-
-     return m
-
-
- @pytest.fixture(scope="session")
- def pymc_mod_no_exog(ss_mod_no_exog, rng):
-     y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"])
-
-     with pm.Model(coords=ss_mod_no_exog.coords) as m:
-         initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
-         P0_sigma = pm.Exponential("P0_sigma", 1)
-         P0 = pm.Deterministic(
-             "P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
-         )
-         sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
-         ss_mod_no_exog.build_statespace_graph(y)
-
-     return m
-
-
- @pytest.fixture(scope="session")
- def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
-     y = pd.DataFrame(
-         rng.normal(size=(100, 1)).astype(floatX),
-         columns=["y"],
-         index=pd.date_range("2020-01-01", periods=100, freq="D"),
-     )
-
-     with pm.Model(coords=ss_mod_no_exog_dt.coords) as m:
-         initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
-         P0_sigma = pm.Exponential("P0_sigma", 1)
-         P0 = pm.Deterministic(
-             "P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
-         )
-         sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
-         ss_mod_no_exog_dt.build_statespace_graph(y)
-
-     return m
-
-
- @pytest.fixture(scope="session")
- def idata(pymc_mod, rng):
-     with pymc_mod:
-         idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
-         idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
-
-     idata.extend(idata_prior)
-     return idata
-
-
- @pytest.fixture(scope="session")
- def idata_exog(exog_pymc_mod, rng):
-     with exog_pymc_mod:
-         idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
-         idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
-     idata.extend(idata_prior)
-     return idata
-
-
- @pytest.fixture(scope="session")
- def idata_no_exog(pymc_mod_no_exog, rng):
-     with pymc_mod_no_exog:
-         idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
-         idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
-     idata.extend(idata_prior)
-     return idata
-
-
- @pytest.fixture(scope="session")
- def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
-     with pymc_mod_no_exog_dt:
-         idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
-         idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
-     idata.extend(idata_prior)
-     return idata
-
-
- def test_invalid_filter_name_raises():
-     msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys()))
-     with pytest.raises(NotImplementedError, match=msg):
-         mod = make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter")
-
-
- def test_unpack_before_insert_raises(rng):
-     p, m, r, n = 2, 5, 1, 10
-     data, *inputs = make_test_inputs(p, m, r, n, rng, missing_data=0)
-     mod = make_statespace_mod(
-         k_endog=p, k_states=m, k_posdef=r, filter_type="standard", verbose=False
-     )
-
-     msg = "Cannot unpack the complete statespace system until PyMC model variables have been inserted."
-     with pytest.raises(ValueError, match=msg):
-         outputs = mod.unpack_statespace()
-
-
- def test_unpack_matrices(rng):
-     p, m, r, n = 2, 5, 1, 10
-     data, *inputs = make_test_inputs(p, m, r, n, rng, missing_data=0)
-     mod = make_statespace_mod(
-         k_endog=p, k_states=m, k_posdef=r, filter_type="standard", verbose=False
-     )
-
-     # mod is a dummy statespace, so there are no placeholders to worry about. Monkey patch subbed_ssm with the defaults
-     mod.subbed_ssm = mod._unpack_statespace_with_placeholders()
-
-     outputs = mod.unpack_statespace()
-     for x, y in zip(inputs, outputs):
-         assert_allclose(np.zeros_like(x), fast_eval(y))
-
-
- def test_param_names_raises_on_base_class():
-     mod = make_statespace_mod(
-         k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False
-     )
-     with pytest.raises(NotImplementedError):
-         x = mod.param_names
-
-
- def test_base_class_raises():
-     with pytest.raises(NotImplementedError):
-         mod = PyMCStateSpace(
-             k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False
-         )
-
-
- def test_update_raises_if_missing_variables(ss_mod):
-     with pm.Model() as mod:
-         rho = pm.Normal("rho")
-         msg = "The following required model parameters were not found in the PyMC model: zeta"
-         with pytest.raises(ValueError, match=msg):
-             ss_mod._insert_random_variables()
-
-
- def test_build_statespace_graph_warns_if_data_has_nans():
-     # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over
-     ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
-
-     with pm.Model() as pymc_mod:
-         initial_trend = pm.Normal("initial_trend", shape=(1,))
-         P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX))
-         with pytest.warns(pm.ImputationWarning):
-             ss_mod.build_statespace_graph(
-                 data=np.full((10, 1), np.nan, dtype=floatX), register_data=False
-             )
-
-
- def test_build_statespace_graph_raises_if_data_has_missing_fill():
-     # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over
-     ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
-
-     with pm.Model() as pymc_mod:
-         initial_trend = pm.Normal("initial_trend", shape=(1,))
-         P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX))
-         with pytest.raises(ValueError, match="Provided data contains the value 1.0"):
-             data = np.ones((10, 1), dtype=floatX)
-             data[3] = np.nan
-             ss_mod.build_statespace_graph(data=data, missing_fill_value=1.0, register_data=False)
-
-
- def test_build_statespace_graph(pymc_mod):
-     for name in [
-         "filtered_state",
-         "predicted_state",
-         "predicted_covariance",
-         "filtered_covariance",
-     ]:
-         assert name in [x.name for x in pymc_mod.deterministics]
-
-
- def test_build_smoother_graph(ss_mod, pymc_mod):
-     names = ["smoothed_state", "smoothed_covariance"]
-     for name in names:
-         assert name in [x.name for x in pymc_mod.deterministics]
-
-
- @pytest.mark.parametrize("group", ["posterior", "prior"])
- @pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
- def test_no_nans_in_sampling_output(group, matrix, idata):
-     assert not np.any(np.isnan(idata[group][matrix].values))
-
-
- @pytest.mark.parametrize("group", ["posterior", "prior"])
- @pytest.mark.parametrize("kind", ["conditional", "unconditional"])
- def test_sampling_methods(group, kind, ss_mod, idata, rng):
-     f = getattr(ss_mod, f"sample_{kind}_{group}")
-     test_idata = f(idata, random_seed=rng)
-
-     if kind == "conditional":
-         for output in ["filtered", "predicted", "smoothed"]:
-             assert f"{output}_{group}" in test_idata
-             assert not np.any(np.isnan(test_idata[f"{output}_{group}"].values))
-             assert not np.any(np.isnan(test_idata[f"{output}_{group}_observed"].values))
-     if kind == "unconditional":
-         for output in ["latent", "observed"]:
-             assert f"{group}_{output}" in test_idata
-             assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
-
-
- @pytest.mark.filterwarnings("ignore:Provided data contains missing values")
- def test_sample_conditional_with_time_varying():
-     class TVCovariance(PyMCStateSpace):
-         def __init__(self):
-             super().__init__(k_states=1, k_endog=1, k_posdef=1)
-
-         def make_symbolic_graph(self) -> None:
-             self.ssm["transition", 0, 0] = 1.0
-
-             self.ssm["design", 0, 0] = 1.0
-
-             sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
-             self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
-
-         @property
-         def param_names(self) -> list[str]:
-             return ["sigma_cov"]
-
-         @property
-         def coords(self) -> dict[str, Sequence[str]]:
-             return make_default_coords(self)
-
-         @property
-         def state_names(self) -> list[str]:
-             return ["level"]
-
-         @property
-         def observed_states(self) -> list[str]:
-             return ["level"]
-
-         @property
-         def shock_names(self) -> list[str]:
-             return ["level"]
-
-     ss_mod = TVCovariance()
-     empty_data = pd.DataFrame(
-         np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
-     )
-
-     coords = ss_mod.coords
-     coords["time"] = empty_data.index
-     with pm.Model(coords=coords) as mod:
-         log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
-         pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
-
-         ss_mod.build_statespace_graph(data=empty_data)
-
-         prior = pm.sample_prior_predictive(10)
-
-     ss_mod.sample_unconditional_prior(prior)
-     ss_mod.sample_conditional_prior(prior)
-
-
- def _make_time_idx(mod, use_datetime_index=True):
-     if use_datetime_index:
-         mod._fit_coords["time"] = nile.index
-         time_idx = nile.index
-     else:
-         mod._fit_coords["time"] = nile.reset_index().index
-         time_idx = pd.RangeIndex(start=0, stop=nile.shape[0], step=1)
-
-     return time_idx
-
-
- @pytest.mark.parametrize("use_datetime_index", [True, False])
- def test_bad_forecast_arguments(use_datetime_index, caplog):
-     ss_mod = make_statespace_mod(
-         k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
-     )
-
-     # Not-fit model raises
-     ss_mod._fit_coords = dict()
-     with pytest.raises(ValueError, match="Has this model been fit?"):
-         ss_mod._get_fit_time_index()
-
-     time_idx = _make_time_idx(ss_mod, use_datetime_index)
-
-     # Start value not in time index
-     match = (
-         "Datetime start must be in the data index used to fit the model"
-         if use_datetime_index
-         else "Integer start must be within the range of the data index used to fit the model."
-     )
-     with pytest.raises(ValueError, match=match):
-         start = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11
-         ss_mod._validate_forecast_args(time_index=time_idx, start=start, periods=10)
-
-     # End value cannot be inferred
-     with pytest.raises(ValueError, match="Must specify one of either periods or end"):
-         start = time_idx[-1]
-         ss_mod._validate_forecast_args(time_index=time_idx, start=start)
-
-     # Unnecessary args warn on verbose
-     start = time_idx[-1]
-     forecast_idx = pd.date_range(start=start, periods=10, freq="YS-JAN")
-     scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
-
-     ss_mod._validate_forecast_args(
-         time_index=time_idx, start=start, periods=10, scenario=scenario, use_scenario_index=True
-     )
-     last_message = caplog.messages[-1]
-     assert "start, end, and periods arguments are ignored" in last_message
-
-     # Verbose=False silences warning
-     ss_mod._validate_forecast_args(
-         time_index=time_idx,
-         start=start,
-         periods=10,
-         scenario=scenario,
-         use_scenario_index=True,
-         verbose=False,
-     )
-     assert len(caplog.messages) == 1
-
-
- @pytest.mark.parametrize("use_datetime_index", [True, False])
- def test_forecast_index(use_datetime_index):
-     ss_mod = make_statespace_mod(
-         k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
-     )
-     ss_mod._fit_coords = dict()
-     time_idx = _make_time_idx(ss_mod, use_datetime_index)
-
-     # From start and end
-     start = time_idx[-1]
-     delta = pd.DateOffset(years=10) if use_datetime_index else 11
-     end = start + delta
-
-     x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
-     assert start not in forecast_idx
-     assert x0_index == start
-     assert forecast_idx.shape == (10,)
-
-     # From start and periods
-     start = time_idx[-1]
-     periods = 10
-
-     x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
-     assert start not in forecast_idx
-     assert x0_index == start
-     assert forecast_idx.shape == (10,)
-
-     # From integer start
-     start = 10
-     x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
-     delta = forecast_idx.freq if use_datetime_index else 1
-
-     assert x0_index == time_idx[start]
-     assert forecast_idx.shape == (10,)
-     assert (forecast_idx == time_idx[start + 1 : start + periods + 1]).all()
-
-     # From scenario index
-     scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
-     new_start, forecast_idx = ss_mod._build_forecast_index(
-         time_index=time_idx, scenario=scenario, use_scenario_index=True
-     )
-     assert x0_index not in forecast_idx
-     assert x0_index == (forecast_idx[0] - delta)
-     assert forecast_idx.shape == (10,)
-     assert forecast_idx.equals(scenario.index)
-
-     # From dictionary of scenarios
-     scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
-     x0_index, forecast_idx = ss_mod._build_forecast_index(
-         time_index=time_idx, scenario=scenario, use_scenario_index=True
-     )
-     assert x0_index == (forecast_idx[0] - delta)
-     assert forecast_idx.shape == (10,)
-     assert forecast_idx.equals(scenario["a"].index)
-
-
- @pytest.mark.parametrize(
-     "data_type",
-     [pd.Series, pd.DataFrame, np.array, list, tuple],
-     ids=["series", "dataframe", "array", "list", "tuple"],
- )
- def test_validate_scenario(data_type):
-     if data_type is pd.DataFrame:
-         # Ensure dataframes have the correct column name
-         data_type = partial(pd.DataFrame, columns=["column_1"])
-
-     # One data case
-     data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(features_a=["column_1"])
-
-     scenario = data_type(np.zeros(10))
-     scenario = ss_mod._validate_scenario_data(scenario)
-
-     # Lists and tuples are cast to 2d arrays
-     if data_type in [tuple, list]:
-         assert isinstance(scenario, np.ndarray)
-         assert scenario.shape == (10, 1)
-
-     # A one-item dictionary should also work
-     scenario = {"a": scenario}
-     ss_mod._validate_scenario_data(scenario)
-
-     # Now data has to be a dictionary
-     data_info.update({"b": {"shape": (None, 1), "dims": ("time", "features_b")}})
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1"])
-
-     scenario = {"a": data_type(np.zeros(10)), "b": data_type(np.zeros(10))}
-     ss_mod._validate_scenario_data(scenario)
-
-     # Mixed data types
-     data_info.update({"a": {"shape": (None, 10), "dims": ("time", "features_a")}})
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(
-         features_a=[f"column_{i}" for i in range(10)], features_b=["column_1"]
-     )
-
-     scenario = {
-         "a": pd.DataFrame(np.zeros((10, 10)), columns=ss_mod._fit_coords["features_a"]),
-         "b": data_type(np.arange(10)),
-     }
-
-     ss_mod._validate_scenario_data(scenario)
-
-
- @pytest.mark.parametrize(
-     "data_type",
-     [pd.Series, pd.DataFrame, np.array, list, tuple],
-     ids=["series", "dataframe", "array", "list", "tuple"],
- )
- @pytest.mark.parametrize("use_datetime_index", [True, False])
- def test_finalize_scenario_single(data_type, use_datetime_index):
-     if data_type is pd.DataFrame:
-         # Ensure dataframes have the correct column name
-         data_type = partial(pd.DataFrame, columns=["column_1"])
-
-     data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(features_a=["column_1"])
-
-     time_idx = _make_time_idx(ss_mod, use_datetime_index)
-
-     scenario = data_type(np.zeros((10,)))
-
-     scenario = ss_mod._validate_scenario_data(scenario)
-     t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
-     scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
-
-     assert isinstance(scenario, pd.DataFrame)
-     assert scenario.index.equals(forecast_idx)
-     assert scenario.columns == ["column_1"]
-
-
- @pytest.mark.parametrize(
-     "data_type",
-     [pd.Series, pd.DataFrame, np.array, list, tuple],
-     ids=["series", "dataframe", "array", "list", "tuple"],
- )
- @pytest.mark.parametrize("use_datetime_index", [True, False])
- @pytest.mark.parametrize("use_scenario_index", [True, False])
- def test_finalize_secenario_dict(data_type, use_datetime_index, use_scenario_index):
-     data_info = {
-         "a": {"shape": (None, 1), "dims": ("time", "features_a")},
-         "b": {"shape": (None, 2), "dims": ("time", "features_b")},
-     }
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1", "column_2"])
-     time_idx = _make_time_idx(ss_mod, use_datetime_index)
-
-     initial_index = (
-         pd.date_range(start=time_idx[-1], periods=10, freq=time_idx.freq)
-         if use_datetime_index
-         else pd.RangeIndex(time_idx[-1], time_idx[-1] + 10, 1)
-     )
-
-     if data_type is pd.DataFrame:
-         # Ensure dataframes have the correct column name
-         data_type = partial(pd.DataFrame, columns=["column_1"], index=initial_index)
-     elif data_type is pd.Series:
-         data_type = partial(pd.Series, index=initial_index)
-
-     scenario = {
-         "a": data_type(np.zeros((10,))),
-         "b": pd.DataFrame(
-             np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"], index=initial_index
-         ),
-     }
-
-     scenario = ss_mod._validate_scenario_data(scenario)
-
-     if use_scenario_index and data_type not in [np.array, list, tuple]:
-         t0, forecast_idx = ss_mod._build_forecast_index(
-             time_idx, scenario=scenario, periods=10, use_scenario_index=True
-         )
-     elif use_scenario_index and data_type in [np.array, list, tuple]:
-         t0, forecast_idx = ss_mod._build_forecast_index(
-             time_idx, scenario=scenario, start=-1, periods=10, use_scenario_index=True
-         )
-     else:
-         t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
-
-     scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
-
-     assert list(scenario.keys()) == ["a", "b"]
-     assert all(isinstance(value, pd.DataFrame) for value in scenario.values())
-     assert all(value.index.equals(forecast_idx) for value in scenario.values())
-
-
- def test_invalid_scenarios():
-     data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(features_a=["column_1", "column_2"])
-
-     # Omitting the data raises
-     with pytest.raises(
-         ValueError, match="This model was fit using exogenous data. Forecasting cannot be performed"
-     ):
-         ss_mod._validate_scenario_data(None)
-
-     # Giving a list, tuple, or Series when a matrix of data is expected should always raise
-     with pytest.raises(
-         ValueError,
-         match="Scenario data for variable 'a' has the wrong number of columns. "
-         "Expected 2, got 1",
-     ):
-         for data_type in [list, tuple, pd.Series]:
-             ss_mod._validate_scenario_data(data_type(np.zeros(10)))
-             ss_mod._validate_scenario_data({"a": data_type(np.zeros(10))})
-
-     # Providing irrevelant data raises
-     with pytest.raises(
-         ValueError,
-         match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable",
-     ):
-         ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})
-
-     # Incorrect 2nd dimension of a non-dataframe
-     with pytest.raises(
-         ValueError,
-         match="Scenario data for variable 'a' has the wrong number of columns. Expected "
-         "2, got 1",
-     ):
-         scenario = np.zeros(10).tolist()
-         ss_mod._validate_scenario_data(scenario)
-         ss_mod._validate_scenario_data(tuple(scenario))
-
-         scenario = {"a": np.zeros(10).tolist()}
-         ss_mod._validate_scenario_data(scenario)
-         ss_mod._validate_scenario_data({"a": tuple(scenario["a"])})
-
-     # If a data frame is provided, it needs to have all columns
-     with pytest.raises(
-         ValueError, match="Scenario data for variable 'a' is missing the following column: column_2"
-     ):
-         scenario = pd.DataFrame(np.zeros((10, 1)), columns=["column_1"])
-         ss_mod._validate_scenario_data(scenario)
-
-     # Extra columns also raises
-     with pytest.raises(
-         ValueError,
-         match="Scenario data for variable 'a' contains the following extra columns "
-         "that are not used by the model: column_3, column_4",
-     ):
-         scenario = pd.DataFrame(
-             np.zeros((10, 4)), columns=["column_1", "column_2", "column_3", "column_4"]
-         )
-         ss_mod._validate_scenario_data(scenario)
-
-     # Wrong number of time steps raises
-     data_info = {
-         "a": {"shape": (None, 1), "dims": ("time", "features_a")},
-         "b": {"shape": (None, 1), "dims": ("time", "features_b")},
-     }
-     ss_mod = make_statespace_mod(
-         k_endog=1,
-         k_posdef=1,
-         k_states=2,
-         filter_type="standard",
-         verbose=False,
-         data_info=data_info,
-     )
-     ss_mod._fit_coords = dict(
-         features_a=["column_1", "column_2"], features_b=["column_1", "column_2"]
-     )
-
-     with pytest.raises(
-         ValueError, match="Scenario data must have the same number of time steps for all variables"
-     ):
-         scenario = {
-             "a": pd.DataFrame(np.zeros((10, 2)), columns=ss_mod._fit_coords["features_a"]),
-             "b": pd.DataFrame(np.zeros((11, 2)), columns=ss_mod._fit_coords["features_b"]),
-         }
-         ss_mod._validate_scenario_data(scenario)
-
-
- @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
- @pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
- @pytest.mark.parametrize(
-     "mod_name, idata_name, start, end, periods",
-     [
-         ("ss_mod_no_exog", "idata_no_exog", None, None, 10),
-         ("ss_mod_no_exog", "idata_no_exog", -1, None, 10),
-         ("ss_mod_no_exog", "idata_no_exog", 10, None, 10),
-         ("ss_mod_no_exog", "idata_no_exog", 10, 21, None),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", None, None, 10),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", -1, None, 10),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, None, 10),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, "2020-01-21", None),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", "2020-03-11", None),
-         ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", None, 10),
-     ],
-     ids=[
-         "range_default",
-         "range_negative",
-         "range_int",
-         "range_end",
-         "datetime_default",
-         "datetime_negative",
-         "datetime_int",
-         "datetime_int_end",
-         "datetime_datetime_end",
-         "datetime_datetime",
-     ],
- )
- def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng, request):
-     mod = request.getfixturevalue(mod_name)
-     idata = request.getfixturevalue(idata_name)
-     time_idx = mod._get_fit_time_index()
-     is_datetime = isinstance(time_idx, pd.DatetimeIndex)
-
-     if isinstance(start, str):
-         t0 = pd.Timestamp(start)
-     elif isinstance(start, int):
-         t0 = time_idx[start]
-     else:
-         t0 = time_idx[-1]
-
-     delta = time_idx.freq if is_datetime else 1
-
-     forecast_idata = mod.forecast(
-         idata, start=start, end=end, periods=periods, filter_output=filter_output, random_seed=rng
-     )
-
-     forecast_idx = forecast_idata.coords["time"].values
-     forecast_idx = pd.DatetimeIndex(forecast_idx) if is_datetime else pd.Index(forecast_idx)
-
-     assert forecast_idx.shape == (10,)
-     assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
-     assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")
-
-     assert not np.any(np.isnan(forecast_idata.forecast_latent.values))
-     assert not np.any(np.isnan(forecast_idata.forecast_observed.values))
-
-     assert forecast_idx[0] == (t0 + delta)
-
-
- @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
- @pytest.mark.parametrize("start", [None, -1, 10])
- def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
-     scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
-     scenario.iloc[5, 0] = 1e9
-
-     forecast_idata = exog_ss_mod.forecast(
-         idata_exog, start=start, periods=10, random_seed=rng, scenario=scenario
-     )
-
-     components = exog_ss_mod.extract_components_from_idata(forecast_idata)
-     level = components.forecast_latent.sel(state="LevelTrend[level]")
-     betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])
-
-     scenario.index.name = "time"
-     scenario_xr = (
-         scenario.unstack()
-         .to_xarray()
-         .rename({"level_0": "state"})
-         .assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
-     )
-
-     regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
-     regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
-
-     assert_allclose(regression_effect, regression_effect_expected)