pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,818 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pymc as pm
6
+ import pytensor
7
+ import pytensor.tensor as pt
8
+ import pytest
9
+
10
+ from numpy.testing import assert_allclose
11
+
12
+ from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
13
+ from pymc_extras.statespace.models import structural as st
14
+ from pymc_extras.statespace.models.utilities import make_default_coords
15
+ from pymc_extras.statespace.utils.constants import (
16
+ FILTER_OUTPUT_NAMES,
17
+ MATRIX_NAMES,
18
+ SMOOTHER_OUTPUT_NAMES,
19
+ )
20
+ from tests.statespace.utilities.shared_fixtures import (
21
+ rng,
22
+ )
23
+ from tests.statespace.utilities.test_helpers import (
24
+ fast_eval,
25
+ load_nile_test_data,
26
+ make_test_inputs,
27
+ )
28
+
29
# Float dtype used for every test matrix (follows the PyTensor configuration).
floatX = pytensor.config.floatX
# Nile river dataset: the shared observed series for the fixtures below.
nile = load_nile_test_data()
# Every statespace output name expected to appear in sampled idata groups.
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
32
+
33
+
34
def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
    """Construct a minimal concrete ``PyMCStateSpace`` for testing.

    The returned model has a no-op symbolic graph; tests manipulate its
    matrices directly.  Passing ``data_info`` (a dict of exogenous-variable
    metadata) marks the model as requiring exogenous data and records the
    exogenous variable names.
    """

    class StateSpace(PyMCStateSpace):
        def make_symbolic_graph(self):
            # Intentionally empty: tests drive the system matrices themselves.
            pass

        @property
        def data_info(self):
            return data_info

    mod = StateSpace(
        k_states=k_states,
        k_endog=k_endog,
        k_posdef=k_posdef,
        filter_type=filter_type,
        verbose=verbose,
    )

    if data_info is None:
        mod._needs_exog_data = False
        mod._exog_names = []
    else:
        mod._needs_exog_data = True
        mod._exog_names = list(data_info.keys())

    return mod
54
+
55
+
56
@pytest.fixture(scope="session")
def ss_mod():
    """Session fixture: hand-built two-state model sized against the Nile data.

    Two free parameters (``rho``, ``zeta``) are wired into the transition
    matrix by ``make_symbolic_graph``; every other system matrix is assigned
    a fixed numeric value directly on the representation below.
    """

    class StateSpace(PyMCStateSpace):
        @property
        def param_names(self):
            return ["rho", "zeta"]

        @property
        def state_names(self):
            return ["a", "b"]

        @property
        def observed_states(self):
            return ["a"]

        @property
        def shock_names(self):
            return ["a"]

        @property
        def coords(self):
            return make_default_coords(self)

        def make_symbolic_graph(self):
            # Register the two scalar parameters and place them in the
            # transition matrix; everything else stays numeric.
            rho = self.make_and_register_variable("rho", ())
            zeta = self.make_and_register_variable("zeta", ())
            self.ssm["transition", 0, 0] = rho
            self.ssm["transition", 1, 0] = zeta

    # Fixed matrices: univariate observation of the first state, one shock
    # loading on the first state, and a near-diffuse initial covariance.
    Z = np.array([[1.0, 0.0]], dtype=floatX)
    R = np.array([[1.0], [0.0]], dtype=floatX)
    H = np.array([[0.1]], dtype=floatX)
    Q = np.array([[0.8]], dtype=floatX)
    P0 = np.eye(2, dtype=floatX) * 1e6

    ss_mod = StateSpace(
        k_endog=nile.shape[1], k_states=2, k_posdef=1, filter_type="standard", verbose=False
    )
    for X, name in zip(
        [Z, R, H, Q, P0],
        ["design", "selection", "obs_cov", "state_cov", "initial_state_cov"],
    ):
        ss_mod.ssm[name] = X

    return ss_mod
101
+
102
+
103
@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
    """Session fixture: PyMC model around ``ss_mod`` with rho ~ Beta(1, 1) and zeta = 1 - rho."""
    with pm.Model(coords=ss_mod.coords) as pymc_mod:
        rho = pm.Beta("rho", 1, 1)
        pm.Deterministic("zeta", 1 - rho)

        ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True)

        # Expose every unpacked system matrix as a named deterministic.
        matrix_names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
        for matrix_name, matrix in zip(matrix_names, ss_mod.unpack_statespace()):
            pm.Deterministic(matrix_name, matrix)

    return pymc_mod
115
+
116
+
117
@pytest.fixture(scope="session")
def ss_mod_no_exog(rng):
    """Session fixture: order-2 level/trend model, innovations on the level only."""
    return st.LevelTrendComponent(order=2, innovations_order=1).build()
121
+
122
+
123
@pytest.fixture(scope="session")
def ss_mod_no_exog_dt(rng):
    """Same model as ``ss_mod_no_exog``; paired with datetime-indexed data downstream."""
    return st.LevelTrendComponent(order=2, innovations_order=1).build()
127
+
128
+
129
@pytest.fixture(scope="session")
def exog_ss_mod(rng):
    """Session fixture: level/trend plus a three-feature regression component."""
    trend = st.LevelTrendComponent()
    regression = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
    return (trend + regression).build(verbose=False)
136
+
137
+
138
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
    """Session fixture: PyMC model for ``exog_ss_mod`` over random y and exogenous X."""
    # 100 draws of a univariate target and three exogenous regressors.
    y = rng.normal(size=(100, 1)).astype(floatX)
    X = rng.normal(size=(100, 3)).astype(floatX)

    with pm.Model(coords=exog_ss_mod.coords) as m:
        exog_data = pm.Data("data_exog", X)
        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        # Isotropic initial state covariance scaled by a single Exponential RV.
        P0 = pm.Deterministic(
            "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        beta_exog = pm.Normal("beta_exog", dims=["exog_state"])

        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
        exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)

    return m
156
+
157
+
158
@pytest.fixture(scope="session")
def pymc_mod_no_exog(ss_mod_no_exog, rng):
    """Session fixture: PyMC model for ``ss_mod_no_exog`` on random integer-indexed data."""
    y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"])

    with pm.Model(coords=ss_mod_no_exog.coords) as m:
        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        # Isotropic initial state covariance scaled by a single Exponential RV.
        P0 = pm.Deterministic(
            "P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
        ss_mod_no_exog.build_statespace_graph(y)

    return m
172
+
173
+
174
@pytest.fixture(scope="session")
def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
    """Session fixture: like ``pymc_mod_no_exog`` but with a daily DatetimeIndex."""
    y = pd.DataFrame(
        rng.normal(size=(100, 1)).astype(floatX),
        columns=["y"],
        index=pd.date_range("2020-01-01", periods=100, freq="D"),
    )

    with pm.Model(coords=ss_mod_no_exog_dt.coords) as m:
        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        # Isotropic initial state covariance scaled by a single Exponential RV.
        P0 = pm.Deterministic(
            "P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
        )
        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
        ss_mod_no_exog_dt.build_statespace_graph(y)

    return m
192
+
193
+
194
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
    """Tiny posterior sample plus prior predictive for ``pymc_mod``."""
    with pymc_mod:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

    trace.extend(prior)
    return trace
202
+
203
+
204
@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
    """Tiny posterior sample plus prior predictive for ``exog_pymc_mod``."""
    with exog_pymc_mod:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
211
+
212
+
213
@pytest.fixture(scope="session")
def idata_no_exog(pymc_mod_no_exog, rng):
    """Tiny posterior sample plus prior predictive for ``pymc_mod_no_exog``."""
    with pymc_mod_no_exog:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
220
+
221
+
222
@pytest.fixture(scope="session")
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
    """Tiny posterior sample plus prior predictive for ``pymc_mod_no_exog_dt``."""
    with pymc_mod_no_exog_dt:
        trace = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
        trace.extend(prior)
    return trace
229
+
230
+
231
def test_invalid_filter_name_raises():
    """An unknown filter_type must raise NotImplementedError listing the valid ones."""
    msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys()))
    with pytest.raises(NotImplementedError, match=msg):
        make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter")
235
+
236
+
237
def test_unpack_before_insert_raises(rng):
    """unpack_statespace must fail until PyMC variables have been inserted."""
    p, m, r, n = 2, 5, 1, 10
    make_test_inputs(p, m, r, n, rng, missing_data=0)
    mod = make_statespace_mod(
        k_endog=p, k_states=m, k_posdef=r, filter_type="standard", verbose=False
    )

    msg = "Cannot unpack the complete statespace system until PyMC model variables have been inserted."
    with pytest.raises(ValueError, match=msg):
        mod.unpack_statespace()
247
+
248
+
249
def test_unpack_matrices(rng):
    """With default (placeholder-free) matrices, every unpacked matrix is all zeros."""
    p, m, r, n = 2, 5, 1, 10
    data, *inputs = make_test_inputs(p, m, r, n, rng, missing_data=0)
    mod = make_statespace_mod(
        k_endog=p, k_states=m, k_posdef=r, filter_type="standard", verbose=False
    )

    # mod is a dummy statespace, so there are no placeholders to substitute;
    # monkey-patch subbed_ssm with the defaults instead.
    mod.subbed_ssm = mod._unpack_statespace_with_placeholders()

    for expected, matrix in zip(inputs, mod.unpack_statespace()):
        assert_allclose(np.zeros_like(expected), fast_eval(matrix))
262
+
263
+
264
def test_param_names_raises_on_base_class():
    """Accessing param_names on a model that never defines it must raise."""
    mod = make_statespace_mod(
        k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False
    )
    with pytest.raises(NotImplementedError):
        mod.param_names
270
+
271
+
272
def test_base_class_raises():
    """PyMCStateSpace itself is abstract: direct construction must raise."""
    with pytest.raises(NotImplementedError):
        PyMCStateSpace(k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False)
277
+
278
+
279
def test_update_raises_if_missing_variables(ss_mod):
    """A PyMC model that is missing a required parameter (zeta) raises on insert."""
    with pm.Model():
        pm.Normal("rho")
        msg = "The following required model parameters were not found in the PyMC model: zeta"
        with pytest.raises(ValueError, match=msg):
            ss_mod._insert_random_variables()
285
+
286
+
287
def test_build_statespace_graph_warns_if_data_has_nans():
    """All-NaN data should emit pm.ImputationWarning while building the graph."""
    # A fresh model is built here (not the session fixtures) because
    # build_statespace_graph cannot be called repeatedly on the same model.
    ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)

    with pm.Model():
        pm.Normal("initial_trend", shape=(1,))
        pm.Deterministic("P0", pt.eye(1, dtype=floatX))
        with pytest.warns(pm.ImputationWarning):
            ss_mod.build_statespace_graph(
                data=np.full((10, 1), np.nan, dtype=floatX), register_data=False
            )
298
+
299
+
300
def test_build_statespace_graph_raises_if_data_has_missing_fill():
    """Data containing the chosen missing_fill_value must be rejected."""
    # A fresh model is built here (not the session fixtures) because
    # build_statespace_graph cannot be called repeatedly on the same model.
    ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)

    with pm.Model():
        pm.Normal("initial_trend", shape=(1,))
        pm.Deterministic("P0", pt.eye(1, dtype=floatX))
        with pytest.raises(ValueError, match="Provided data contains the value 1.0"):
            data = np.ones((10, 1), dtype=floatX)
            data[3] = np.nan
            ss_mod.build_statespace_graph(data=data, missing_fill_value=1.0, register_data=False)
311
+
312
+
313
def test_build_statespace_graph(pymc_mod):
    """All Kalman-filter outputs must be registered as model deterministics."""
    deterministic_names = {x.name for x in pymc_mod.deterministics}
    expected = (
        "filtered_state",
        "predicted_state",
        "predicted_covariance",
        "filtered_covariance",
    )
    for name in expected:
        assert name in deterministic_names
321
+
322
+
323
def test_build_smoother_graph(ss_mod, pymc_mod):
    """Smoother outputs must also be registered as model deterministics."""
    deterministic_names = {x.name for x in pymc_mod.deterministics}
    for name in ("smoothed_state", "smoothed_covariance"):
        assert name in deterministic_names
327
+
328
+
329
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
def test_no_nans_in_sampling_output(group, matrix, idata):
    """No sampled statespace output may contain NaN."""
    values = idata[group][matrix].values
    assert not np.isnan(values).any()
333
+
334
+
335
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
def test_sampling_methods(group, kind, ss_mod, idata, rng):
    """sample_{kind}_{group} methods produce the expected NaN-free output groups."""
    sample_fn = getattr(ss_mod, f"sample_{kind}_{group}")
    test_idata = sample_fn(idata, random_seed=rng)

    if kind == "conditional":
        # Conditional sampling yields filtered/predicted/smoothed state and
        # observed trajectories.
        for output in ("filtered", "predicted", "smoothed"):
            assert f"{output}_{group}" in test_idata
            assert not np.any(np.isnan(test_idata[f"{output}_{group}"].values))
            assert not np.any(np.isnan(test_idata[f"{output}_{group}_observed"].values))
    if kind == "unconditional":
        # Unconditional sampling yields latent and observed trajectories.
        for output in ("latent", "observed"):
            assert f"{group}_{output}" in test_idata
            assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
350
+
351
+
352
def _make_time_idx(mod, use_datetime_index=True):
    """Attach a 'time' fit-coordinate from the Nile data to ``mod``.

    Returns the index a caller should treat as the fit-time index: the Nile
    DatetimeIndex, or an equivalent RangeIndex when ``use_datetime_index`` is
    False.
    """
    if use_datetime_index:
        time_idx = nile.index
        mod._fit_coords["time"] = time_idx
    else:
        mod._fit_coords["time"] = nile.reset_index().index
        time_idx = pd.RangeIndex(start=0, stop=nile.shape[0], step=1)

    return time_idx
361
+
362
+
363
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_bad_forecast_arguments(use_datetime_index, caplog):
    """Exercise the argument-validation failure paths of ``_validate_forecast_args``."""
    ss_mod = make_statespace_mod(
        k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
    )

    # Not-fit model raises
    ss_mod._fit_coords = dict()
    with pytest.raises(ValueError, match="Has this model been fit?"):
        ss_mod._get_fit_time_index()

    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    # Start value not in time index
    match = (
        "Datetime start must be in the data index used to fit the model"
        if use_datetime_index
        else "Integer start must be within the range of the data index used to fit the model."
    )
    with pytest.raises(ValueError, match=match):
        # Shift past the end of the fit index so the start is out of range.
        start = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11
        ss_mod._validate_forecast_args(time_index=time_idx, start=start, periods=10)

    # End value cannot be inferred
    with pytest.raises(ValueError, match="Must specify one of either periods or end"):
        start = time_idx[-1]
        ss_mod._validate_forecast_args(time_index=time_idx, start=start)

    # Unnecessary args warn on verbose
    start = time_idx[-1]
    forecast_idx = pd.date_range(start=start, periods=10, freq="YS-JAN")
    scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])

    ss_mod._validate_forecast_args(
        time_index=time_idx, start=start, periods=10, scenario=scenario, use_scenario_index=True
    )
    last_message = caplog.messages[-1]
    assert "start, end, and periods arguments are ignored" in last_message

    # Verbose=False silences warning — the log record count must stay at the
    # single message captured above.
    ss_mod._validate_forecast_args(
        time_index=time_idx,
        start=start,
        periods=10,
        scenario=scenario,
        use_scenario_index=True,
        verbose=False,
    )
    assert len(caplog.messages) == 1
412
+
413
+
414
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_forecast_index(use_datetime_index):
    """Check ``_build_forecast_index`` for every supported way to specify the horizon.

    Covers: (start, end), (start, periods), an integer positional start, a
    scenario DataFrame's own index, and a dict of scenario DataFrames.
    """
    ss_mod = make_statespace_mod(
        k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False
    )
    ss_mod._fit_coords = dict()
    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    # From start and end
    start = time_idx[-1]
    delta = pd.DateOffset(years=10) if use_datetime_index else 11
    end = start + delta

    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
    assert start not in forecast_idx
    assert x0_index == start
    assert forecast_idx.shape == (10,)

    # From start and periods
    start = time_idx[-1]
    periods = 10

    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
    assert start not in forecast_idx
    assert x0_index == start
    assert forecast_idx.shape == (10,)

    # From integer start
    start = 10
    x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
    delta = forecast_idx.freq if use_datetime_index else 1

    assert x0_index == time_idx[start]
    assert forecast_idx.shape == (10,)
    assert (forecast_idx == time_idx[start + 1 : start + periods + 1]).all()

    # From scenario index
    # BUG FIX: the returned start was previously discarded into an unused
    # ``new_start`` variable while the assertions re-checked the stale
    # ``x0_index`` from the integer-start case above; assert on the value
    # actually returned by this call instead.
    scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
    x0_index, forecast_idx = ss_mod._build_forecast_index(
        time_index=time_idx, scenario=scenario, use_scenario_index=True
    )
    assert x0_index not in forecast_idx
    assert x0_index == (forecast_idx[0] - delta)
    assert forecast_idx.shape == (10,)
    assert forecast_idx.equals(scenario.index)

    # From dictionary of scenarios
    scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
    x0_index, forecast_idx = ss_mod._build_forecast_index(
        time_index=time_idx, scenario=scenario, use_scenario_index=True
    )
    assert x0_index == (forecast_idx[0] - delta)
    assert forecast_idx.shape == (10,)
    assert forecast_idx.equals(scenario["a"].index)
468
+
469
+
470
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
def test_validate_scenario(data_type):
    """``_validate_scenario_data`` accepts every supported scenario container type."""
    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"])

    # One data case
    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"])

    scenario = data_type(np.zeros(10))
    scenario = ss_mod._validate_scenario_data(scenario)

    # Lists and tuples are cast to 2d arrays
    if data_type in [tuple, list]:
        assert isinstance(scenario, np.ndarray)
        assert scenario.shape == (10, 1)

    # A one-item dictionary should also work
    scenario = {"a": scenario}
    ss_mod._validate_scenario_data(scenario)

    # Now data has to be a dictionary (two exogenous variables)
    data_info.update({"b": {"shape": (None, 1), "dims": ("time", "features_b")}})
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1"])

    scenario = {"a": data_type(np.zeros(10)), "b": data_type(np.zeros(10))}
    ss_mod._validate_scenario_data(scenario)

    # Mixed data types: a wide DataFrame for 'a', the parametrized type for 'b'
    data_info.update({"a": {"shape": (None, 10), "dims": ("time", "features_a")}})
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(
        features_a=[f"column_{i}" for i in range(10)], features_b=["column_1"]
    )

    scenario = {
        "a": pd.DataFrame(np.zeros((10, 10)), columns=ss_mod._fit_coords["features_a"]),
        "b": data_type(np.arange(10)),
    }

    ss_mod._validate_scenario_data(scenario)
539
+
540
+
541
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
@pytest.mark.parametrize("use_datetime_index", [True, False])
def test_finalize_scenario_single(data_type, use_datetime_index):
    """A single scenario of any supported type is finalized into an indexed DataFrame."""
    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"])

    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"])

    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    scenario = data_type(np.zeros((10,)))

    # Validate, build the forecast index, then finalize against that index.
    scenario = ss_mod._validate_scenario_data(scenario)
    t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
    scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)

    assert isinstance(scenario, pd.DataFrame)
    assert scenario.index.equals(forecast_idx)
    assert scenario.columns == ["column_1"]
574
+
575
+
576
@pytest.mark.parametrize(
    "data_type",
    [pd.Series, pd.DataFrame, np.array, list, tuple],
    ids=["series", "dataframe", "array", "list", "tuple"],
)
@pytest.mark.parametrize("use_datetime_index", [True, False])
@pytest.mark.parametrize("use_scenario_index", [True, False])
def test_finalize_scenario_dict(data_type, use_datetime_index, use_scenario_index):
    """A dict of scenarios is finalized into DataFrames sharing the forecast index.

    Fixed typo in the test name (was ``test_finalize_secenario_dict``); the
    body is unchanged.
    """
    data_info = {
        "a": {"shape": (None, 1), "dims": ("time", "features_a")},
        "b": {"shape": (None, 2), "dims": ("time", "features_b")},
    }
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1", "column_2"])
    time_idx = _make_time_idx(ss_mod, use_datetime_index)

    # Index the scenarios will carry when built as pandas objects.
    initial_index = (
        pd.date_range(start=time_idx[-1], periods=10, freq=time_idx.freq)
        if use_datetime_index
        else pd.RangeIndex(time_idx[-1], time_idx[-1] + 10, 1)
    )

    if data_type is pd.DataFrame:
        # Ensure dataframes have the correct column name
        data_type = partial(pd.DataFrame, columns=["column_1"], index=initial_index)
    elif data_type is pd.Series:
        data_type = partial(pd.Series, index=initial_index)

    scenario = {
        "a": data_type(np.zeros((10,))),
        "b": pd.DataFrame(
            np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"], index=initial_index
        ),
    }

    scenario = ss_mod._validate_scenario_data(scenario)

    # Index-free containers (array/list/tuple) need an explicit start when
    # using the scenario index; pandas containers carry their own index.
    if use_scenario_index and data_type not in [np.array, list, tuple]:
        t0, forecast_idx = ss_mod._build_forecast_index(
            time_idx, scenario=scenario, periods=10, use_scenario_index=True
        )
    elif use_scenario_index and data_type in [np.array, list, tuple]:
        t0, forecast_idx = ss_mod._build_forecast_index(
            time_idx, scenario=scenario, start=-1, periods=10, use_scenario_index=True
        )
    else:
        t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)

    scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)

    assert list(scenario.keys()) == ["a", "b"]
    assert all(isinstance(value, pd.DataFrame) for value in scenario.values())
    assert all(value.index.equals(forecast_idx) for value in scenario.values())
636
+
637
+
638
def test_invalid_scenarios():
    """Every malformed scenario input must raise an informative ValueError.

    BUG FIX: several ``pytest.raises`` blocks previously wrapped *loops or
    multiple calls*, so only the first call was ever exercised (the first
    raise exits the block).  Each call now gets its own raises context.
    """
    data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}}
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(features_a=["column_1", "column_2"])

    # Omitting the data raises
    with pytest.raises(
        ValueError, match="This model was fit using exogenous data. Forecasting cannot be performed"
    ):
        ss_mod._validate_scenario_data(None)

    # Giving a list, tuple, or Series when a matrix of data is expected should always raise,
    # whether passed bare or inside a dict.
    for data_type in [list, tuple, pd.Series]:
        with pytest.raises(
            ValueError,
            match="Scenario data for variable 'a' has the wrong number of columns. "
            "Expected 2, got 1",
        ):
            ss_mod._validate_scenario_data(data_type(np.zeros(10)))
        with pytest.raises(
            ValueError,
            match="Scenario data for variable 'a' has the wrong number of columns. "
            "Expected 2, got 1",
        ):
            ss_mod._validate_scenario_data({"a": data_type(np.zeros(10))})

    # Providing irrelevant data raises
    with pytest.raises(
        ValueError,
        match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable",
    ):
        ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})

    # Incorrect 2nd dimension of a non-dataframe
    base = np.zeros(10).tolist()
    for bad_scenario in (base, tuple(base), {"a": base}, {"a": tuple(base)}):
        with pytest.raises(
            ValueError,
            match="Scenario data for variable 'a' has the wrong number of columns. Expected "
            "2, got 1",
        ):
            ss_mod._validate_scenario_data(bad_scenario)

    # If a data frame is provided, it needs to have all columns
    with pytest.raises(
        ValueError, match="Scenario data for variable 'a' is missing the following column: column_2"
    ):
        ss_mod._validate_scenario_data(pd.DataFrame(np.zeros((10, 1)), columns=["column_1"]))

    # Extra columns also raises
    with pytest.raises(
        ValueError,
        match="Scenario data for variable 'a' contains the following extra columns "
        "that are not used by the model: column_3, column_4",
    ):
        ss_mod._validate_scenario_data(
            pd.DataFrame(
                np.zeros((10, 4)), columns=["column_1", "column_2", "column_3", "column_4"]
            )
        )

    # Wrong number of time steps raises
    data_info = {
        "a": {"shape": (None, 1), "dims": ("time", "features_a")},
        "b": {"shape": (None, 1), "dims": ("time", "features_b")},
    }
    ss_mod = make_statespace_mod(
        k_endog=1,
        k_posdef=1,
        k_states=2,
        filter_type="standard",
        verbose=False,
        data_info=data_info,
    )
    ss_mod._fit_coords = dict(
        features_a=["column_1", "column_2"], features_b=["column_1", "column_2"]
    )

    with pytest.raises(
        ValueError, match="Scenario data must have the same number of time steps for all variables"
    ):
        ss_mod._validate_scenario_data(
            {
                "a": pd.DataFrame(np.zeros((10, 2)), columns=ss_mod._fit_coords["features_a"]),
                "b": pd.DataFrame(np.zeros((11, 2)), columns=ss_mod._fit_coords["features_b"]),
            }
        )
730
+
731
+
732
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
@pytest.mark.parametrize(
    "mod_name, idata_name, start, end, periods",
    [
        ("ss_mod_no_exog", "idata_no_exog", None, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", -1, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", 10, None, 10),
        ("ss_mod_no_exog", "idata_no_exog", 10, 21, None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", None, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", -1, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, None, 10),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, "2020-01-21", None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", "2020-03-11", None),
        ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", None, 10),
    ],
    ids=[
        "range_default",
        "range_negative",
        "range_int",
        "range_end",
        "datetime_default",
        "datetime_negative",
        "datetime_int",
        "datetime_int_end",
        "datetime_datetime_end",
        "datetime_datetime",
    ],
)
def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng, request):
    """Forecasting gives a 10-step, NaN-free horizon for every start/end/periods combo."""
    mod = request.getfixturevalue(mod_name)
    idata = request.getfixturevalue(idata_name)
    time_idx = mod._get_fit_time_index()
    is_datetime = isinstance(time_idx, pd.DatetimeIndex)

    # Resolve the expected forecast origin t0 from the start argument.
    if isinstance(start, str):
        t0 = pd.Timestamp(start)
    elif isinstance(start, int):
        t0 = time_idx[start]
    else:
        t0 = time_idx[-1]

    delta = time_idx.freq if is_datetime else 1

    forecast_idata = mod.forecast(
        idata, start=start, end=end, periods=periods, filter_output=filter_output, random_seed=rng
    )

    forecast_idx = forecast_idata.coords["time"].values
    forecast_idx = pd.DatetimeIndex(forecast_idx) if is_datetime else pd.Index(forecast_idx)

    assert forecast_idx.shape == (10,)
    assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
    assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")

    assert not np.any(np.isnan(forecast_idata.forecast_latent.values))
    assert not np.any(np.isnan(forecast_idata.forecast_observed.values))

    # The first forecast step must fall one period after the origin.
    assert forecast_idx[0] == (t0 + delta)
791
+
792
+
793
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.parametrize("start", [None, -1, 10])
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
    """The regression effect in the forecast equals beta @ scenario at each step."""
    # Mostly-zero scenario with one extreme value to make the regression
    # contribution clearly visible in the forecast.
    scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
    scenario.iloc[5, 0] = 1e9

    forecast_idata = exog_ss_mod.forecast(
        idata_exog, start=start, periods=10, random_seed=rng, scenario=scenario
    )

    components = exog_ss_mod.extract_components_from_idata(forecast_idata)
    level = components.forecast_latent.sel(state="LevelTrend[level]")
    betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])

    # Reshape the scenario into an xarray aligned with the beta states so the
    # expected regression effect can be computed by broadcasting.
    scenario.index.name = "time"
    scenario_xr = (
        scenario.unstack()
        .to_xarray()
        .rename({"level_0": "state"})
        .assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
    )

    # Observed forecast minus the level component isolates the regression effect.
    regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
    regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])

    assert_allclose(regression_effect, regression_effect_expected)