pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,829 @@
1
+ import functools as ft
2
+ import warnings
3
+
4
+ from collections import defaultdict
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pymc as pm
10
+ import pytensor
11
+ import pytensor.tensor as pt
12
+ import pytest
13
+ import statsmodels.api as sm
14
+
15
+ from numpy.testing import assert_allclose
16
+ from scipy import linalg
17
+
18
+ from pymc_extras.statespace import structural as st
19
+ from pymc_extras.statespace.utils.constants import (
20
+ ALL_STATE_AUX_DIM,
21
+ ALL_STATE_DIM,
22
+ AR_PARAM_DIM,
23
+ OBS_STATE_AUX_DIM,
24
+ OBS_STATE_DIM,
25
+ SHOCK_AUX_DIM,
26
+ SHOCK_DIM,
27
+ SHORT_NAME_TO_LONG,
28
+ )
29
+ from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
30
+ rng,
31
+ )
32
+ from tests.statespace.utilities.test_helpers import (
33
+ assert_pattern_repeats,
34
+ simulate_from_numpy_model,
35
+ unpack_symbolic_matrices_with_params,
36
+ )
37
+
38
+ floatX = pytensor.config.floatX
39
+ ATOL = 1e-8 if floatX.endswith("64") else 1e-4
40
+ RTOL = 0 if floatX.endswith("64") else 1e-6
41
+
42
+
43
def _assert_all_statespace_matrices_match(mod, params, sm_mod):
    """Check that the symbolic statespace matrices of ``mod`` agree with ``sm_mod``.

    The initial state vector is compared only when it is non-empty, and any
    matrix with a zero-length dimension (e.g. no shocks in a fully
    deterministic model) is skipped entirely.
    """
    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, params)
    sm_x0, sm_H0, sm_P0 = sm_mod.initialization()

    if len(x0) > 0:
        assert_allclose(x0, sm_x0)

    matrices_to_check = {"T": T, "R": R, "Z": Z, "Q": Q}
    for name, matrix in matrices_to_check.items():
        # Degenerate (zero-sized) matrices have nothing to compare.
        if 0 in matrix.shape:
            continue
        long_name = SHORT_NAME_TO_LONG[name]
        assert_allclose(
            sm_mod.ssm[long_name],
            matrix,
            err_msg=f"matrix {name} does not match statsmodels",
            atol=ATOL,
            rtol=RTOL,
        )
63
+
64
def _assert_coord_shapes_match_matrices(mod, params):
    """Verify that the coordinate lengths on ``mod`` match the trailing shapes of
    its statespace matrices."""
    if "initial_state_cov" not in params:
        params["initial_state_cov"] = np.eye(mod.k_states)

    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, params)

    n_states = len(mod.coords[ALL_STATE_DIM])

    # There will always be one shock dimension -- dummies are inserted into fully deterministic models to avoid errors
    # in the state space representation.
    n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
    n_obs = len(mod.coords[OBS_STATE_DIM])

    # (name, matrix, number of trailing dims to check, expected shape, human-readable shape)
    checks = [
        ("x0", x0, 1, (n_states,), "(n_states, )"),
        ("P0", P0, 2, (n_states, n_states), "(n_states, n_states)"),
        ("c", c, 1, (n_states,), "(n_states, )"),
        ("d", d, 1, (n_obs,), "(n_obs, )"),
        ("T", T, 2, (n_states, n_states), "(n_states, n_states)"),
        ("Z", Z, 2, (n_obs, n_states), "(n_obs, n_states)"),
        ("R", R, 2, (n_states, n_shocks), "(n_states, n_shocks)"),
        ("H", H, 2, (n_obs, n_obs), "(n_obs, n_obs)"),
        ("Q", Q, 2, (n_shocks, n_shocks), "(n_shocks, n_shocks)"),
    ]
    for name, matrix, ndim, expected, desc in checks:
        actual = matrix.shape[-ndim:]
        assert actual == expected, f"{name} expected to have shape {desc}, found {actual}"
109
+
110
def _assert_basic_coords_correct(mod):
    """Check the six default coordinate dimensions present on every structural model."""
    expected = {
        ALL_STATE_DIM: mod.state_names,
        ALL_STATE_AUX_DIM: mod.state_names,
        SHOCK_DIM: mod.shock_names,
        SHOCK_AUX_DIM: mod.shock_names,
        OBS_STATE_DIM: ["data"],
        OBS_STATE_AUX_DIM: ["data"],
    }
    for dim, value in expected.items():
        assert mod.coords[dim] == value
118
+
119
+ def _assert_keys_match(test_dict, expected_dict):
120
+ expected_keys = list(expected_dict.keys())
121
+ param_keys = list(test_dict.keys())
122
+ key_diff = set(expected_keys) - set(param_keys)
123
+ assert len(key_diff) == 0, f'{", ".join(key_diff)} were not found in the test_dict keys.'
124
+
125
+ key_diff = set(param_keys) - set(expected_keys)
126
+ assert (
127
+ len(key_diff) == 0
128
+ ), f'{", ".join(key_diff)} were keys of the tests_dict not in expected_dict.'
129
+
130
+
131
def _assert_param_dims_correct(param_dims, expected_dims):
    """Assert that each parameter's dims match expectations; both empty is trivially correct."""
    if not expected_dims and not param_dims:
        return

    _assert_keys_match(param_dims, expected_dims)
    for param, dims in expected_dims.items():
        assert dims == param_dims[param], f"dims for parameter {param} do not match"
139
+
140
def _assert_coords_correct(coords, expected_coords):
    """Assert that each dimension's labels match expectations; both empty is trivially correct."""
    if not coords and not expected_coords:
        return

    _assert_keys_match(coords, expected_coords)
    for dim, labels in expected_coords.items():
        assert labels == coords[dim], f"labels on dimension {dim} do not match"
148
+
149
+ def _assert_params_info_correct(param_info, coords, param_dims):
150
+ for param in param_info.keys():
151
+ info = param_info[param]
152
+
153
+ dims = info["dims"]
154
+ labels = [coords[dim] for dim in dims] if dims is not None else None
155
+ if labels is not None:
156
+ assert param in param_dims.keys()
157
+ inferred_dims = param_dims[param]
158
+ else:
159
+ inferred_dims = None
160
+
161
+ shape = tuple(len(label) for label in labels) if labels is not None else ()
162
+
163
+ assert info["shape"] == shape
164
+ assert dims == inferred_dims
165
+
166
+
167
def create_structural_model_and_equivalent_statsmodel(
    rng,
    level: bool | None = False,
    trend: bool | None = False,
    seasonal: int | None = None,
    freq_seasonal: list[dict] | None = None,
    cycle: bool = False,
    autoregressive: int | None = None,
    exog: np.ndarray | None = None,
    irregular: bool | None = False,
    stochastic_level: bool | None = True,
    stochastic_trend: bool | None = False,
    stochastic_seasonal: bool | None = True,
    stochastic_freq_seasonal: list[bool] | None = None,
    stochastic_cycle: bool | None = False,
    damped_cycle: bool | None = False,
):
    """Build a pymc-extras structural model and a matching statsmodels factory.

    For each requested component, the function draws random parameter values with
    ``rng``, records them under both the pymc-extras naming scheme (``params``) and
    the statsmodels naming scheme (``sm_params`` for fitted parameters, ``sm_init``
    for initial states), and accumulates the coordinate/dims metadata the built
    model is expected to expose.

    Returns
    -------
    tuple
        ``(mod, st_mod, params, sm_params, sm_init, expected_param_dims,
        expected_coords)`` where ``mod`` is an *unbuilt* ``functools.partial`` of
        ``sm.tsa.UnobservedComponents`` (call it with data to instantiate) and
        ``st_mod`` is the sum of the pymc-extras components.

    Notes
    -----
    The order of ``rng`` draws below is significant: callers rely on the exact
    draw sequence for reproducibility, so components must not be reordered.
    """
    # statsmodels warns about some component combinations; the warnings are
    # irrelevant here because we only compare matrices, not fits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        mod = ft.partial(
            sm.tsa.UnobservedComponents,
            level=level,
            trend=trend,
            seasonal=seasonal,
            freq_seasonal=freq_seasonal,
            cycle=cycle,
            autoregressive=autoregressive,
            exog=exog,
            irregular=irregular,
            stochastic_level=stochastic_level,
            stochastic_trend=stochastic_trend,
            stochastic_seasonal=stochastic_seasonal,
            stochastic_freq_seasonal=stochastic_freq_seasonal,
            stochastic_cycle=stochastic_cycle,
            damped_cycle=damped_cycle,
            mle_regression=False,
        )

    params = {}       # pymc-extras parameter values
    sm_params = {}    # statsmodels fitted-parameter values
    sm_init = {}      # statsmodels initial-state values
    expected_param_dims = defaultdict(tuple)
    expected_coords = defaultdict(list)
    # P0 always has (state, state_aux) dims regardless of components.
    expected_param_dims["P0"] += ("state", "state_aux")

    # Every structural model exposes these six dimensions; only the observed
    # dimension has a default label.
    default_states = [
        ALL_STATE_DIM,
        ALL_STATE_AUX_DIM,
        OBS_STATE_DIM,
        OBS_STATE_AUX_DIM,
        SHOCK_DIM,
        SHOCK_AUX_DIM,
    ]
    default_values = [[], [], ["data"], ["data"], [], []]
    for dim, value in zip(default_states, default_values):
        expected_coords[dim] += value

    components = []

    # --- Measurement error ------------------------------------------------
    if irregular:
        sigma2 = np.abs(rng.normal()).astype(floatX).item()
        # pymc-extras parameterizes by sigma, statsmodels by sigma^2.
        params["sigma_irregular"] = np.sqrt(sigma2)
        sm_params["sigma2.irregular"] = sigma2

        comp = st.MeasurementError("irregular")
        components.append(comp)

    # --- Level / trend ----------------------------------------------------
    # Order flags: [level, trend]; innovation flags likewise.
    level_trend_order = [0, 0]
    level_trend_innov_order = [0, 0]

    if level:
        level_trend_order[0] = 1
        expected_coords["trend_state"] += [
            "level",
        ]
        expected_coords[ALL_STATE_DIM] += [
            "level",
        ]
        expected_coords[ALL_STATE_AUX_DIM] += [
            "level",
        ]
        if stochastic_level:
            level_trend_innov_order[0] = 1
            expected_coords["trend_shock"] += ["level"]
            expected_coords[SHOCK_DIM] += [
                "level",
            ]
            expected_coords[SHOCK_AUX_DIM] += [
                "level",
            ]

    if trend:
        level_trend_order[1] = 1
        expected_coords["trend_state"] += [
            "trend",
        ]
        expected_coords[ALL_STATE_DIM] += [
            "trend",
        ]
        expected_coords[ALL_STATE_AUX_DIM] += [
            "trend",
        ]

        if stochastic_trend:
            level_trend_innov_order[1] = 1
            expected_coords["trend_shock"] += ["trend"]
            expected_coords[SHOCK_DIM] += ["trend"]
            expected_coords[SHOCK_AUX_DIM] += ["trend"]

    if level or trend:
        expected_param_dims["initial_trend"] += ("trend_state",)
        # Draw initial values only for the active (level/trend) slots.
        level_value = np.where(
            level_trend_order,
            rng.normal(
                size=2,
            ).astype(floatX),
            np.zeros(2, dtype=floatX),
        )
        # Innovation variances only for the stochastic slots.
        sigma_level_value2 = np.abs(rng.normal(size=(2,)))[
            np.array(level_trend_innov_order, dtype="bool")
        ]
        # Truncate trailing inactive orders (e.g. level-only model has order 1).
        max_order = np.flatnonzero(level_value)[-1].item() + 1
        level_trend_order = level_trend_order[:max_order]

        params["initial_trend"] = level_value[:max_order]
        sm_init["level"] = level_value[0]
        sm_init["trend"] = level_value[1]

        if sum(level_trend_innov_order) > 0:
            expected_param_dims["sigma_trend"] += ("trend_shock",)
            params["sigma_trend"] = np.sqrt(sigma_level_value2)

            # statsmodels wants one scalar per stochastic slot, in level-then-trend order.
            sigma_level_value = sigma_level_value2.tolist()
            if stochastic_level:
                sigma = sigma_level_value.pop(0)
                sm_params["sigma2.level"] = sigma
            if stochastic_trend:
                sigma = sigma_level_value.pop(0)
                sm_params["sigma2.trend"] = sigma

        comp = st.LevelTrendComponent(
            name="level", order=level_trend_order, innovations_order=level_trend_innov_order
        )
        components.append(comp)

    # --- Time-domain seasonality ------------------------------------------
    if seasonal is not None:
        # One state per season minus one (the first is implied by the others).
        state_names = [f"seasonal_{i}" for i in range(seasonal)][1:]
        seasonal_coefs = rng.normal(size=(seasonal - 1,)).astype(floatX)
        params["seasonal_coefs"] = seasonal_coefs
        expected_param_dims["seasonal_coefs"] += ("seasonal_state",)

        expected_coords["seasonal_state"] += tuple(state_names)
        expected_coords[ALL_STATE_DIM] += state_names
        expected_coords[ALL_STATE_AUX_DIM] += state_names

        # statsmodels names the states "seasonal", "seasonal.L1", "seasonal.L2", ...
        seasonal_dict = {
            "seasonal" if i == 0 else f"seasonal.L{i}": c for i, c in enumerate(seasonal_coefs)
        }
        sm_init.update(seasonal_dict)

        if stochastic_seasonal:
            sigma2 = np.abs(rng.normal()).astype(floatX)
            params["sigma_seasonal"] = np.sqrt(sigma2)
            sm_params["sigma2.seasonal"] = sigma2
            expected_coords[SHOCK_DIM] += [
                "seasonal",
            ]
            expected_coords[SHOCK_AUX_DIM] += [
                "seasonal",
            ]

        comp = st.TimeSeasonality(
            name="seasonal", season_length=seasonal, innovations=stochastic_seasonal
        )
        components.append(comp)

    # --- Frequency-domain seasonality -------------------------------------
    if freq_seasonal is not None:
        state_count = 0
        for d, has_innov in zip(freq_seasonal, stochastic_freq_seasonal):
            n = d["harmonics"]
            s = d["period"]
            # When the period is exactly twice the number of harmonics, the last
            # (sine) state is not identified and is dropped.
            last_state_not_identified = (s / n) == 2.0
            n_states = 2 * n - int(last_state_not_identified)
            state_names = [f"seasonal_{s}_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"]]

            seasonal_params = rng.normal(size=n_states).astype(floatX)

            params[f"seasonal_{s}"] = seasonal_params
            expected_param_dims[f"seasonal_{s}"] += (f"seasonal_{s}_state",)
            expected_coords[ALL_STATE_DIM] += state_names
            expected_coords[ALL_STATE_AUX_DIM] += state_names
            expected_coords[f"seasonal_{s}_state"] += (
                tuple(state_names[:-1]) if last_state_not_identified else tuple(state_names)
            )

            # statsmodels flattens all frequency-seasonal states into one
            # running index; an unidentified last state is initialized to 0.
            for param in seasonal_params:
                sm_init[f"freq_seasonal.{state_count}"] = param
                state_count += 1
            if last_state_not_identified:
                sm_init[f"freq_seasonal.{state_count}"] = 0.0
                state_count += 1

            if has_innov:
                sigma2 = np.abs(rng.normal()).astype(floatX)
                params[f"sigma_seasonal_{s}"] = np.sqrt(sigma2)
                sm_params[f"sigma2.freq_seasonal_{s}({n})"] = sigma2
                expected_coords[SHOCK_DIM] += state_names
                expected_coords[SHOCK_AUX_DIM] += state_names

            comp = st.FrequencySeasonality(
                name=f"seasonal_{s}", season_length=s, n=n, innovations=has_innov
            )
            components.append(comp)

    # --- Cycle ------------------------------------------------------------
    if cycle:
        # NOTE(review): this uses the global numpy RNG rather than the ``rng``
        # fixture, so the cycle length is not controlled by the test seed --
        # presumably unintentional; confirm before changing.
        cycle_length = np.random.choice(np.arange(2, 12)).astype(floatX)

        # Statsmodels takes the frequency not the cycle length, so convert it.
        sm_params["frequency.cycle"] = 2.0 * np.pi / cycle_length
        params["cycle_length"] = cycle_length

        init_cycle = rng.normal(size=(2,)).astype(floatX)
        params["cycle"] = init_cycle
        expected_param_dims["cycle"] += ("cycle_state",)

        state_names = ["cycle_Cos", "cycle_Sin"]
        expected_coords["cycle_state"] += state_names
        expected_coords[ALL_STATE_DIM] += state_names
        expected_coords[ALL_STATE_AUX_DIM] += state_names

        sm_init["cycle"] = init_cycle[0]
        sm_init["cycle.auxilliary"] = init_cycle[1]

        if stochastic_cycle:
            sigma2 = np.abs(rng.normal()).astype(floatX)
            params["sigma_cycle"] = np.sqrt(sigma2)
            expected_coords[SHOCK_DIM] += state_names
            expected_coords[SHOCK_AUX_DIM] += state_names

            sm_params["sigma2.cycle"] = sigma2

        if damped_cycle:
            rho = rng.beta(1, 1)
            params["cycle_dampening_factor"] = rho
            sm_params["damping.cycle"] = rho

        comp = st.CycleComponent(
            name="cycle",
            dampen=damped_cycle,
            innovations=stochastic_cycle,
            estimate_cycle_length=True,
        )

        components.append(comp)

    # --- Autoregressive ---------------------------------------------------
    if autoregressive is not None:
        ar_names = [f"L{i+1}.data" for i in range(autoregressive)]
        ar_params = rng.normal(size=(autoregressive,)).astype(floatX)
        if autoregressive == 1:
            # statsmodels expects a scalar for a single AR coefficient.
            ar_params = ar_params.item()
        sigma2 = np.abs(rng.normal()).astype(floatX)

        params["ar_params"] = ar_params
        params["sigma_ar"] = np.sqrt(sigma2)
        expected_param_dims["ar_params"] += (AR_PARAM_DIM,)
        expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
        expected_coords[ALL_STATE_DIM] += ar_names
        expected_coords[ALL_STATE_AUX_DIM] += ar_names
        expected_coords[SHOCK_DIM] += ["ar_innovation"]
        expected_coords[SHOCK_AUX_DIM] += ["ar_innovation"]

        sm_params["sigma2.ar"] = sigma2
        for i, rho in enumerate(ar_params):
            sm_init[f"ar.L{i+1}"] = 0
            sm_params[f"ar.L{i+1}"] = rho

        comp = st.AutoregressiveComponent(name="ar", order=autoregressive)
        components.append(comp)

    # --- Exogenous regression ---------------------------------------------
    if exog is not None:
        names = [f"x{i + 1}" for i in range(exog.shape[1])]
        betas = rng.normal(size=(exog.shape[1],)).astype(floatX)
        params["beta_exog"] = betas
        params["data_exog"] = exog
        expected_param_dims["beta_exog"] += ("exog_state",)
        expected_param_dims["data_exog"] += ("time", "exog_data")

        expected_coords["exog_state"] += tuple(names)

        for i, beta in enumerate(betas):
            sm_params[f"beta.x{i + 1}"] = beta
            sm_init[f"beta.x{i+1}"] = beta
        comp = st.RegressionComponent(name="exog", state_names=names)
        components.append(comp)

    # Sum the components into a single structural model.
    st_mod = components.pop(0)
    for comp in components:
        st_mod += comp
    return mod, st_mod, params, sm_params, sm_init, expected_param_dims, expected_coords
467
+
468
@pytest.mark.parametrize(
    "level, trend, stochastic_level, stochastic_trend, irregular",
    [
        (False, False, False, False, True),
        (True, True, True, True, True),
        (True, True, False, True, False),
    ],
)
@pytest.mark.parametrize("autoregressive", [None, 3])
@pytest.mark.parametrize("seasonal, stochastic_seasonal", [(None, False), (12, False), (12, True)])
@pytest.mark.parametrize(
    "freq_seasonal, stochastic_freq_seasonal",
    [
        (None, None),
        ([{"period": 12, "harmonics": 2}], [False]),
        ([{"period": 12, "harmonics": 6}], [True]),
    ],
)
@pytest.mark.parametrize(
    "cycle, damped_cycle, stochastic_cycle",
    [(False, False, False), (True, False, True), (True, True, True)],
)
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.ConvergenceWarning")
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.SpecificationWarning")
def test_structural_model_against_statsmodels(
    level,
    trend,
    stochastic_level,
    stochastic_trend,
    irregular,
    autoregressive,
    seasonal,
    stochastic_seasonal,
    freq_seasonal,
    stochastic_freq_seasonal,
    cycle,
    damped_cycle,
    stochastic_cycle,
    rng,
):
    """Exhaustively compare structural models against statsmodels UnobservedComponents.

    For every parametrized component combination: build both models with the same
    randomly drawn parameters, push the parameters into the statsmodels model via
    ``initialize_known``/``update``, and then check matrices, coordinate shapes,
    dims, coords, and parameter info of the pymc-extras model.
    """
    retvals = create_structural_model_and_equivalent_statsmodel(
        rng,
        level=level,
        trend=trend,
        seasonal=seasonal,
        freq_seasonal=freq_seasonal,
        cycle=cycle,
        damped_cycle=damped_cycle,
        autoregressive=autoregressive,
        irregular=irregular,
        stochastic_level=stochastic_level,
        stochastic_trend=stochastic_trend,
        stochastic_seasonal=stochastic_seasonal,
        stochastic_freq_seasonal=stochastic_freq_seasonal,
        stochastic_cycle=stochastic_cycle,
    )
    f_sm_mod, mod, params, sm_params, sm_init, expected_dims, expected_coords = retvals

    # The statsmodels factory needs data before it can be instantiated.
    data = rng.normal(size=(100,)).astype(floatX)
    sm_mod = f_sm_mod(data)

    # Initialize the statsmodels state vector in the model's own state order,
    # skipping the dummy state inserted for fully deterministic models.
    if len(sm_init) > 0:
        init_array = np.concatenate(
            [np.atleast_1d(sm_init[k]).ravel() for k in sm_mod.state_names if k != "dummy"]
        )
        sm_mod.initialize_known(init_array, np.eye(sm_mod.k_states))
    else:
        sm_mod.initialize_default()

    # Push the drawn parameter values into the statsmodels representation.
    if len(sm_params) > 0:
        param_array = np.concatenate(
            [np.atleast_1d(sm_params[k]).ravel() for k in sm_mod.param_names]
        )
        sm_mod.update(param_array, transformed=True)

    _assert_all_statespace_matrices_match(mod, params, sm_mod)

    built_model = mod.build(verbose=False)

    _assert_coord_shapes_match_matrices(built_model, params)
    _assert_param_dims_correct(built_model.param_dims, expected_dims)
    _assert_coords_correct(built_model.coords, expected_coords)
    _assert_params_info_correct(built_model.param_info, built_model.coords, built_model.param_dims)
552
+
553
def test_level_trend_model(rng):
    """A deterministic order-2 level/trend with unit slope yields increments of exactly 1."""
    component = st.LevelTrendComponent(order=2, innovations_order=0)
    trend_params = {"initial_trend": [0.0, 1.0]}
    _, y = simulate_from_numpy_model(component, rng, trend_params)

    # Slope of 1 per step means all first differences equal 1.
    assert_allclose(np.diff(y), 1, atol=ATOL, rtol=RTOL)

    # Check coords
    built = component.build(verbose=False)
    _assert_basic_coords_correct(built)
    assert built.coords["trend_state"] == ["level", "trend"]
565
+
566
def test_measurement_error(rng):
    """MeasurementError combined with a LevelTrend exposes a ``sigma_obs`` parameter."""
    combined = st.MeasurementError("obs") + st.LevelTrendComponent(order=2)
    built = combined.build(verbose=False)

    _assert_basic_coords_correct(built)
    assert "sigma_obs" in built.param_names
573
+
574
@pytest.mark.parametrize("order", [1, 2, [1, 0, 1]], ids=["AR1", "AR2", "AR(1,0,1)"])
def test_autoregressive_model(order, rng):
    """AR component simulates without error and exposes the correct lag coordinates."""
    ar = st.AutoregressiveComponent(order=order)
    ar_sim_params = {
        "ar_params": np.full((sum(ar.order),), 0.5, dtype=floatX),
        "sigma_ar": 0.0,
    }

    simulate_from_numpy_model(ar, rng, ar_sim_params, steps=100)

    # Check coords
    ar.build(verbose=False)
    _assert_basic_coords_correct(ar)

    # Lags are 1-indexed; with a list-valued order only non-zero entries are kept.
    n_lags = len(order) if isinstance(order, list) else order
    lags = np.arange(n_lags, dtype="int") + 1
    if isinstance(order, list):
        lags = lags[np.flatnonzero(order)]
    assert_allclose(ar.coords["ar_lag"], lags)
592
+
593
@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
def test_time_seasonality(s, innovations, rng):
    """TimeSeasonality output repeats with period ``s`` when innovations are off."""

    def random_word(rng):
        # Five random lowercase letters, drawn one at a time.
        return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))

    state_names = [random_word(rng) for _ in range(s)]
    mod = st.TimeSeasonality(
        season_length=s, innovations=innovations, name="season", state_names=state_names
    )

    # Impulse in the first seasonal state, zeros everywhere else.
    initial_coefs = np.zeros(mod.k_states, dtype=floatX)
    initial_coefs[0] = 1

    sim_params = {"season_coefs": initial_coefs}
    if mod.innovations:
        sim_params["sigma_season"] = 0.0

    _, y = simulate_from_numpy_model(mod, rng, sim_params)
    y = y.ravel()
    if not innovations:
        assert_pattern_repeats(y, s, atol=ATOL, rtol=RTOL)

    # Check coords
    mod.build(verbose=False)
    _assert_basic_coords_correct(mod)
    # The first state is implied by the rest, so it has no coordinate label.
    assert mod.coords["season_state"] == state_names[1:]
620
+
621
def get_shift_factor(s):
    """Return the power of ten that turns a (possibly fractional) season length
    into an integer number of steps."""
    _, sep, decimal = str(s).partition(".")
    if not sep:
        # No decimal point: already an integer period.
        return 1
    return 10 ** len(decimal)
628
+
629
@pytest.mark.parametrize("n", [*np.arange(1, 6, dtype="int").tolist(), None])
@pytest.mark.parametrize("s", [5, 10, 25, 25.2])
def test_frequency_seasonality(n, s, rng):
    """FrequencySeasonality repeats with the (integerized) season length and names its states."""
    mod = st.FrequencySeasonality(season_length=s, n=n, name="season")
    init_coefs = rng.normal(size=mod.n_coefs).astype(floatX)
    sim_params = {"season": init_coefs, "sigma_season": 0.0}

    # Fractional season lengths are scaled up to an integer period for the repeat check.
    shift = get_shift_factor(s)
    period = int(s * shift)

    _, y = simulate_from_numpy_model(mod, rng, sim_params, steps=2 * period)
    assert_pattern_repeats(y, period, atol=ATOL, rtol=RTOL)

    # Check coords
    mod.build(verbose=False)
    _assert_basic_coords_correct(mod)
    if n is None:
        n = int(s // 2)
    states = [f"season_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"]]

    # Remove the last state when the model is completely saturated
    if s / n == 2.0:
        states.pop()
    assert mod.coords["season_state"] == states
653
+
654
# NOTE(review): appears to be an unused leftover -- no test below consumes it. Confirm and remove.
cycle_test_vals = zip([None, None, 3, 5, 10], [False, True, True, False, False])
655
+
656
+
657
def test_cycle_component_deterministic(rng):
    """A noiseless cycle with fixed length 12 repeats exactly every 12 steps."""
    component = st.CycleComponent(
        name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False
    )
    init_params = {"cycle": np.array([1.0, 1.0], dtype=floatX)}
    _, y = simulate_from_numpy_model(component, rng, init_params, steps=12 * 12)

    assert_pattern_repeats(y, 12, atol=ATOL, rtol=RTOL)
666
+
667
def test_cycle_component_with_dampening(rng):
    """With dampening factor < 1 and no noise, the cycle decays toward zero."""
    component = st.CycleComponent(
        name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False, dampen=True
    )
    init_params = {"cycle": np.array([10.0, 10.0], dtype=floatX), "cycle_dampening_factor": 0.75}
    _, y = simulate_from_numpy_model(component, rng, init_params, steps=100)

    # Check that the cycle dampens to zero over time
    assert_allclose(y[-1], 0.0, atol=ATOL, rtol=RTOL)
677
+
678
def test_cycle_component_with_innovations_and_cycle_length(rng):
    """A fully-featured cycle (estimated length, dampening, innovations) simulates and builds."""
    component = st.CycleComponent(
        name="cycle", estimate_cycle_length=True, innovations=True, dampen=True
    )
    sim_params = {
        "cycle": np.array([1.0, 1.0], dtype=floatX),
        "cycle_length": 12.0,
        "cycle_dampening_factor": 0.95,
        "sigma_cycle": 1.0,
    }

    simulate_from_numpy_model(component, rng, sim_params)

    component.build(verbose=False)
    _assert_basic_coords_correct(component)
694
+
695
def test_exogenous_component(rng):
    """A RegressionComponent reproduces a plain linear regression on its exogenous data."""
    features = rng.normal(size=(100, 2)).astype(floatX)
    component = st.RegressionComponent(state_names=["feature_1", "feature_2"], name="exog")

    beta_params = {"beta_exog": np.array([1.0, 2.0], dtype=floatX)}
    _, y = simulate_from_numpy_model(component, rng, beta_params, {"data_exog": features})

    # Check that the generated data is just a linear regression
    assert_allclose(y, features @ beta_params["beta_exog"], atol=ATOL, rtol=RTOL)

    component.build(verbose=False)
    _assert_basic_coords_correct(component)
    assert component.coords["exog_state"] == ["feature_1", "feature_2"]
710
+
711
def test_adding_exogenous_component(rng):
    """Combining a regression component with other components widens the design matrix.

    The combined design matrix is time-varying with shape (time, k_endog, k_states)
    = (100, 1, 12): 2 regression states + 2 level/trend states + 8 frequency
    seasonality states. At each time step, the first two design columns hold that
    step's exogenous row.
    """
    data = rng.normal(size=(100, 2)).astype(floatX)
    reg = st.RegressionComponent(state_names=["a", "b"], name="exog")
    ll = st.LevelTrendComponent(name="level")

    seasonal = st.FrequencySeasonality(name="annual", season_length=12, n=4)
    mod = reg + ll + seasonal

    assert mod.ssm["design"].eval({"data_exog": data}).shape == (100, 1, 2 + 2 + 8)
    assert_allclose(mod.ssm["design", 5, 0, :2].eval({"data_exog": data}), data[5])
722
+
723
def test_add_components():
    """Adding two components combines their statespace matrices and metadata.

    T/R/Q of the sum are block-diagonal in the components' matrices; x0/c are
    concatenated along the state axis and Z along its column axis. All metadata
    of each component must be carried over to the sum.
    """
    ll = st.LevelTrendComponent(order=2)
    se = st.TimeSeasonality(name="seasonal", season_length=12)
    mod = ll + se

    ll_params = {
        "initial_trend": np.zeros(2, dtype=floatX),
        "sigma_trend": np.ones(2, dtype=floatX),
    }
    se_params = {
        "seasonal_coefs": np.ones(11, dtype=floatX),
        "sigma_seasonal": 1.0,
    }
    all_params = ll_params.copy()
    all_params.update(se_params)

    (ll_x0, ll_P0, ll_c, ll_d, ll_T, ll_Z, ll_R, ll_H, ll_Q) = unpack_symbolic_matrices_with_params(
        ll, ll_params
    )
    (se_x0, se_P0, se_c, se_d, se_T, se_Z, se_R, se_H, se_Q) = unpack_symbolic_matrices_with_params(
        se, se_params
    )
    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, all_params)

    # BUG FIX: the original asserted the list comprehension itself
    # (`assert [x in ... for x in ...]`), which is truthy for any non-empty list
    # and so never failed. `all(...)` performs the intended membership check.
    for prop in ["param_names", "shock_names", "param_info", "coords", "param_dims"]:
        assert all(x in getattr(mod, prop) for x in getattr(ll, prop))
        assert all(x in getattr(mod, prop) for x in getattr(se, prop))

    # Transition, selection, and state-cov matrices combine block-diagonally.
    ll_mats = [ll_T, ll_R, ll_Q]
    se_mats = [se_T, se_R, se_Q]
    all_mats = [T, R, Q]

    for ll_mat, se_mat, all_mat in zip(ll_mats, se_mats, all_mats):
        assert_allclose(all_mat, linalg.block_diag(ll_mat, se_mat), atol=ATOL, rtol=RTOL)

    # Initial state, state intercept, and design matrices concatenate.
    ll_mats = [ll_x0, ll_c, ll_Z]
    se_mats = [se_x0, se_c, se_Z]
    all_mats = [x0, c, Z]
    axes = [0, 0, 1]

    for ll_mat, se_mat, all_mat, axis in zip(ll_mats, se_mats, all_mats, axes):
        assert_allclose(all_mat, np.concatenate([ll_mat, se_mat], axis=axis), atol=ATOL, rtol=RTOL)
766
+
767
def test_filter_scans_time_varying_design_matrix(rng):
    """The Kalman filter scan must carry the exogenous data through the design matrix.

    Samples Z from the prior predictive and checks that every draw's design
    matrix row equals the exogenous data at the matching time step.
    """
    time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)
    data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx)

    y = pd.DataFrame(rng.normal(size=(100, 1)), columns=["data"], index=time_idx)

    reg = st.RegressionComponent(state_names=["a", "b"], name="exog")
    mod = reg.build(verbose=False)

    with pm.Model(coords=mod.coords) as m:
        data_exog = pm.Data("data_exog", data.values)

        x0 = pm.Normal("x0", dims=["state"])
        P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
        beta_exog = pm.Normal("beta_exog", dims=["exog_state"])

        mod.build_statespace_graph(y)
        # Expose the (time-varying) design matrix so it shows up in the prior trace.
        x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
        pm.Deterministic("Z", Z)

        prior = pm.sample_prior_predictive(draws=10)

    # Z has shape (chain, draw, time, k_endog, k_states) = (1, 10, 100, 1, 2);
    # every draw should reproduce the exogenous data exactly.
    prior_Z = prior.prior.Z.values
    assert prior_Z.shape == (1, 10, 100, 1, 2)
    assert_allclose(prior_Z[0, :, :, 0, :], data.values[None].repeat(10, axis=0))
793
+
794
@pytest.mark.skipif(floatX.endswith("32"), reason="Prior covariance not PSD at half-precision")
def test_extract_components_from_idata(rng):
    """``extract_components_from_idata`` labels filtered states by their component.

    Builds a four-component model, samples the conditional prior, and checks the
    extracted component states carry the expected ``Component[state]`` labels.
    """
    time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)
    data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx)

    y = pd.DataFrame(rng.normal(size=(100, 1)), columns=["data"], index=time_idx)

    ll = st.LevelTrendComponent()
    season = st.FrequencySeasonality(name="seasonal", season_length=12, n=2, innovations=False)
    reg = st.RegressionComponent(state_names=["a", "b"], name="exog")
    me = st.MeasurementError("obs")
    mod = (ll + season + reg + me).build(verbose=False)

    with pm.Model(coords=mod.coords) as m:
        data_exog = pm.Data("data_exog", data.values)

        # One prior per parameter the built model expects.
        x0 = pm.Normal("x0", dims=["state"])
        P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
        beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
        seasonal_coefs = pm.Normal("seasonal", dims=["seasonal_state"])
        sigma_obs = pm.Exponential("sigma_obs", 1)

        mod.build_statespace_graph(y)

        x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
        prior = pm.sample_prior_predictive(draws=10)

    filter_prior = mod.sample_conditional_prior(prior)
    comp_prior = mod.extract_components_from_idata(filter_prior)
    comp_states = comp_prior.filtered_prior.coords["state"].values
    expected_states = ["LevelTrend[level]", "LevelTrend[trend]", "seasonal", "exog[a]", "exog[b]"]
    missing = set(comp_states) - set(expected_states)

    assert len(missing) == 0, missing
+ assert len(missing) == 0, missing