pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,670 @@
+ from collections.abc import Sequence
+ from typing import Any
+
+ import numpy as np
+ import pytensor.tensor as pt
+
+ from pytensor import graph_replace
+ from pytensor.tensor.slinalg import solve_discrete_lyapunov
+
+ from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
+ from pymc_extras.statespace.models.utilities import make_default_coords
+ from pymc_extras.statespace.utils.constants import (
+     ALL_STATE_AUX_DIM,
+     ALL_STATE_DIM,
+     ETS_SEASONAL_DIM,
+     OBS_STATE_AUX_DIM,
+     OBS_STATE_DIM,
+ )
+
+
+ class BayesianETS(PyMCStateSpace):
+     def __init__(
+         self,
+         order: tuple[str, str, str] | None = None,
+         endog_names: str | list[str] | None = None,
+         k_endog: int = 1,
+         trend: bool = True,
+         damped_trend: bool = False,
+         seasonal: bool = False,
+         seasonal_periods: int | None = None,
+         measurement_error: bool = False,
+         use_transformed_parameterization: bool = False,
+         dense_innovation_covariance: bool = False,
+         stationary_initialization: bool = False,
+         initialization_dampening: float = 0.8,
+         filter_type: str = "standard",
+         verbose: bool = True,
+     ):
+ r"""
40
+ Exponential Smoothing State Space Model
41
+
42
+ This class can represent a subset of exponential smoothing state space models, specifically those with additive
43
+ errors. Following .. [1], The general form of the model is:
44
+
45
+ .. math::
46
+
47
+ \begin{align}
48
+ y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
49
+ \epsilon_t &\sim N(0, \sigma)
50
+ \end{align}
51
+
52
+ where :math:`l_t` is the level component, :math:`b_t` is the trend component, and :math:`s_t` is the seasonal
53
+ component. These components can be included or excluded, leading to different model specifications. The following
54
+ models are possible:
55
+
56
+ * `ETS(A,N,N)`: Simple exponential smoothing
57
+
58
+ .. math::
59
+
60
+ \begin{align}
61
+ y_t &= l_{t-1} + \epsilon_t \\
62
+ l_t &= l_{t-1} + \alpha \epsilon_t
63
+ \end{align}
64
+
65
+ Where :math:`\alpha \in [0, 1]` is a mixing parameter between past observations and current innovations.
66
+ These equations arise by starting from the "component form":
67
+
68
+ .. math::
69
+
70
+ \begin{align}
71
+ \hat{y}_{t+1 | t} &= l_t \\
72
+ l_t &= \alpha y_t + (1 - \alpha) l_{t-1} \\
73
+ &= l_{t-1} + \alpha (y_t - l_{t-1})
74
+ &= l_{t-1} + \alpha \epsilon_t
75
+ \end{align}
76
+
77
+ Where $\epsilon_t$ are the forecast errors, assumed to be IID mean zero and normally distributed. The role of
78
+ :math:`\alpha` is clearest in the second line. The level of the time series at each time is a mixture of
79
+ :math:`\alpha` percent of the incoming data, and :math:`1 - \alpha` percent of the previous level. Recursive
80
+ substitution reveals that the level is a weighted composite of all previous observations; thus the name
81
+ "Exponential Smoothing".
+
+         Additional supported specifications include:
+
+         * `ETS(A,A,N)`: Holt's linear trend method
+
+             .. math::
+
+                 \begin{align}
+                 y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
+                 l_t &= l_{t-1} + b_{t-1} + \alpha \epsilon_t \\
+                 b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t
+                 \end{align}
+
+         [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
+
+         * `ETS(A,N,A)`: Additive seasonal method
+
+             .. math::
+
+                 \begin{align}
+                 y_t &= l_{t-1} + s_{t-m} + \epsilon_t \\
+                 l_t &= l_{t-1} + \alpha \epsilon_t \\
+                 s_t &= s_{t-m} + (1 - \alpha)\gamma^\star \epsilon_t
+                 \end{align}
+
+         [1]_ also consider an alternative parameterization with :math:`\gamma = (1 - \alpha) \gamma^\star`.
+
+         * `ETS(A,A,A)`: Additive Holt-Winters method
+
+             .. math::
+
+                 \begin{align}
+                 y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
+                 l_t &= l_{t-1} + \alpha \epsilon_t \\
+                 b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t \\
+                 s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
+                 \end{align}
+
+         [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star` and
+         :math:`\gamma = (1 - \alpha) \gamma^\star`.
+
+         * `ETS(A, Ad, N)`: Damped trend method
+
+             .. math::
+
+                 \begin{align}
+                 y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
+                 l_t &= l_{t-1} + \alpha \epsilon_t \\
+                 b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t
+                 \end{align}
+
+         [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
+
+         * `ETS(A, Ad, A)`: Damped trend with seasonal method
+
+             .. math::
+
+                 \begin{align}
+                 y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
+                 l_t &= l_{t-1} + \alpha \epsilon_t \\
+                 b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t \\
+                 s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
+                 \end{align}
+
+         [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star` and
+         :math:`\gamma = (1 - \alpha) \gamma^\star`.
+
+
+         Parameters
+         ----------
+         order: tuple of string, Optional
+             The exponential smoothing "order". This is a tuple of three strings: the error type (only 'A',
+             additive errors, is supported), the trend type (one of 'A', 'Ad', or 'N'), and the seasonal type
+             ('A' or 'N').
+             If provided, the model will be initialized from the given order, and the `trend`, `damped_trend`,
+             and `seasonal` arguments will be ignored.
+         endog_names: str or list of str, Optional
+             Names associated with observed states. If a list, the length should be equal to the number of time
+             series to be estimated.
+         k_endog: int, Optional
+             Number of time series to estimate. If endog_names are provided, this is ignored and len(endog_names)
+             is used instead.
+         trend: bool
+             Whether to include a trend component. Setting ``trend=True`` is equivalent to ``order[1] == 'A'``.
+         damped_trend: bool
+             Whether to include a damping parameter on the trend component. Ignored if `trend` is `False`. Setting
+             ``trend=True`` and ``damped_trend=True`` is equivalent to ``order[1] == 'Ad'``.
+         seasonal: bool
+             Whether to include a seasonal component. Setting ``seasonal=True`` is equivalent to ``order[2] == 'A'``.
+         seasonal_periods: int
+             The number of periods in a complete seasonal cycle. Ignored if `seasonal` is `False`
+             (or if ``order[2] == "N"``).
+         measurement_error: bool
+             Whether to include a measurement error term in the model. Default is `False`.
+         use_transformed_parameterization: bool, default False
+             If True, use the :math:`\alpha, \beta, \gamma` parameterization; otherwise use the :math:`\alpha,
+             \beta^\star, \gamma^\star` parameterization. This will change the admissible region for the priors.
+
+             - Under the **non-transformed** parameterization, all of :math:`\alpha, \beta^\star, \gamma^\star`
+               should be between 0 and 1.
+             - Under the **transformed** parameterization, :math:`\alpha \in (0, 1)`, :math:`\beta \in (0, \alpha)`,
+               and :math:`\gamma \in (0, 1 - \alpha)`.
+
+             The :meth:`param_info` method will change to reflect the suggested intervals based on the value of
+             this argument.
+         dense_innovation_covariance: bool, default False
+             Whether to estimate a dense covariance for statespace innovations. In an ETS model, each observed
+             variable has a single source of stochastic variation. If True, these innovations are allowed to be
+             correlated. Ignored if ``k_endog == 1``.
+         stationary_initialization: bool, default False
+             If True, the Kalman Filter's initial covariance matrix will be set to an approximate steady-state
+             value. The approximation is formed by applying a dampening factor to each state. Specifically, the
+             level state for a ('A', 'N', 'N') model is written:
+
+             .. math::
+                 \ell_t = \ell_{t-1} + \alpha e_t
+
+             That this system is not stationary can be understood in ARIMA terms: the level is a random walk;
+             that is, :math:`\rho = 1`. This can be remedied by pretending that we instead have a dampened system:
+
+             .. math::
+                 \ell_t = \rho \ell_{t-1} + \alpha e_t
+
+             With :math:`\rho \approx 1`, the system is stationary, and we can solve for the steady-state
+             covariance matrix. This is then used as the initial covariance matrix for the Kalman Filter. This is
+             a heuristic method that helps avoid setting a prior on the initial covariance matrix.
+         initialization_dampening: float, default 0.8
+             Dampening factor to apply to non-stationary model components. This is only used for initialization;
+             it does *not* add dampening to the model. Ignored if `stationary_initialization` is `False`.
+         filter_type: str, default "standard"
+             The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
+             and "cholesky". See the docs for kalman filters for more details.
+         verbose: bool, default True
+             If True, a message will be logged to the terminal explaining the variable names, dimensions, and
+             supports.
+
+         References
+         ----------
+         .. [1] Hyndman, Rob J., and George Athanasopoulos. Forecasting: principles and practice. OTexts, 2018.
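+
+         Examples
+         --------
+         The following is a minimal sketch of intended usage, not taken verbatim from the package. The prior
+         choices, the name ``data`` (standing in for an observed series), and the ``build_statespace_graph`` call
+         are illustrative assumptions.
+
+         .. code-block:: python
+
+             import pymc as pm
+             from pymc_extras.statespace.models.ETS import BayesianETS
+
+             # Holt's linear trend model, ETS(A, A, N), with the heuristic stationary initialization
+             ets = BayesianETS(order=("A", "A", "N"), stationary_initialization=True)
+
+             with pm.Model(coords=ets.coords) as model:
+                 initial_level = pm.Normal("initial_level", 0, 10)
+                 initial_trend = pm.Normal("initial_trend", 0, 1)
+                 alpha = pm.Beta("alpha", 2, 2)
+                 beta = pm.Beta("beta", 2, 2)
+                 sigma_state = pm.HalfNormal("sigma_state", 1)
+
+                 ets.build_statespace_graph(data)
+                 idata = pm.sample()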
+         """
+
+         if order is not None:
+             if len(order) != 3 or any(not isinstance(o, str) for o in order):
+                 raise ValueError("Order must be a tuple of three strings.")
+             if order[0] != "A":
+                 raise ValueError("Only additive errors are supported.")
+             if order[1] not in {"A", "Ad", "N"}:
+                 raise ValueError(
+                     f"Invalid trend specification. Only 'A' (additive), 'Ad' (additive with dampening), "
+                     f"or 'N' (no trend) are allowed. Found {order[1]}"
+                 )
+             if order[2] not in {"A", "N"}:
+                 raise ValueError(
+                     f"Invalid seasonal specification. Only 'A' (additive) or 'N' (no seasonal component) "
+                     f"are allowed. Found {order[2]}"
+                 )
+
+             trend = order[1] != "N"
+             damped_trend = order[1] == "Ad"
+             seasonal = order[2] == "A"
+
+         self.trend = trend
+         self.damped_trend = damped_trend
+         self.seasonal = seasonal
+         self.seasonal_periods = seasonal_periods
+         self.use_transformed_parameterization = use_transformed_parameterization
+         self.stationary_initialization = stationary_initialization
+
+         if not (0.0 < initialization_dampening < 1.0):
+             raise ValueError(
+                 "Dampening term used for initialization must be between 0 and 1 (preferably close to "
+                 "1.0)"
+             )
+
+         self.initialization_dampening = initialization_dampening
+
+         if self.seasonal and self.seasonal_periods is None:
+             raise ValueError("If seasonal is True, seasonal_periods must be provided.")
+
+         if endog_names is not None:
+             endog_names = [endog_names] if isinstance(endog_names, str) else list(endog_names)
+             k_endog = len(endog_names)
+         else:
+             endog_names = [f"data_{i}" for i in range(k_endog)] if k_endog > 1 else ["data"]
+
+         self.endog_names = endog_names
+
+         if dense_innovation_covariance and k_endog == 1:
+             dense_innovation_covariance = False
+
+         self.dense_innovation_covariance = dense_innovation_covariance
+
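+         # Per series, the state vector holds an innovation state, a level state, an optional trend
+         # state, and (when a seasonal component is requested) `seasonal_periods` seasonal states.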
+         k_states = (
+             2
+             + int(trend)
+             + int(seasonal) * (seasonal_periods if seasonal_periods is not None else 0)
+         ) * k_endog
+
+         k_posdef = k_endog
+
+         super().__init__(
+             k_endog,
+             k_states,
+             k_posdef,
+             filter_type,
+             verbose=verbose,
+             measurement_error=measurement_error,
+         )
+
+     @property
+     def param_names(self):
+         names = [
+             "initial_level",
+             "initial_trend",
+             "initial_seasonal",
+             "P0",
+             "alpha",
+             "beta",
+             "gamma",
+             "phi",
+             "sigma_state",
+             "state_cov",
+             "sigma_obs",
+         ]
+         if not self.trend:
+             names.remove("initial_trend")
+             names.remove("beta")
+         if not self.damped_trend:
+             names.remove("phi")
+         if not self.seasonal:
+             names.remove("initial_seasonal")
+             names.remove("gamma")
+         if not self.measurement_error:
+             names.remove("sigma_obs")
+
+         if self.dense_innovation_covariance:
+             names.remove("sigma_state")
+         else:
+             names.remove("state_cov")
+
+         if self.stationary_initialization:
+             names.remove("P0")
+
+         return names
+
+     @property
+     def param_info(self) -> dict[str, dict[str, Any]]:
+         info = {
+             "P0": {
+                 "shape": (self.k_states, self.k_states),
+                 "constraints": "Positive Semi-definite",
+             },
+             "initial_level": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": None,
+             },
+             "initial_trend": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": None,
+             },
+             "initial_seasonal": {"shape": (self.seasonal_periods,), "constraints": None},
+             "sigma_obs": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": "Positive",
+             },
+             "sigma_state": {
+                 "shape": None if self.k_posdef == 1 else (self.k_posdef,),
+                 "constraints": "Positive",
+             },
+             "alpha": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": "0 < alpha < 1",
+             },
+             "beta": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": "0 < beta < 1"
+                 if not self.use_transformed_parameterization
+                 else "0 < beta < alpha",
+             },
+             "gamma": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": "0 < gamma < 1"
+                 if not self.use_transformed_parameterization
+                 else "0 < gamma < (1 - alpha)",
+             },
+             "phi": {
+                 "shape": None if self.k_endog == 1 else (self.k_endog,),
+                 "constraints": "0 < phi < 1",
+             },
+         }
+
+         if self.dense_innovation_covariance:
+             del info["sigma_state"]
+             info["state_cov"] = {
+                 "shape": (self.k_posdef, self.k_posdef),
+                 "constraints": "Positive Semi-definite",
+             }
+
+         for name in self.param_names:
+             info[name]["dims"] = self.param_dims.get(name, None)
+
+         return {name: info[name] for name in self.param_names}
+
+     @property
+     def state_names(self):
+         states = ["innovation", "level"]
+         if self.trend:
+             states += ["trend"]
+         if self.seasonal:
+             states += ["seasonality"]
+             states += [f"L{i}.season" for i in range(1, self.seasonal_periods)]
+
+         if self.k_endog > 1:
+             states = [f"{name}_{state}" for name in self.endog_names for state in states]
+
+         return states
+
+     @property
+     def observed_states(self):
+         return self.endog_names
+
+     @property
+     def shock_names(self):
+         return (
+             ["innovation"]
+             if self.k_endog == 1
+             else [f"{name}_innovation" for name in self.endog_names]
+         )
+
+     @property
+     def param_dims(self):
+         coord_map = {
+             "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
+             "sigma_obs": (OBS_STATE_DIM,),
+             "sigma_state": (OBS_STATE_DIM,),
+             "initial_level": (OBS_STATE_DIM,),
+             "initial_trend": (OBS_STATE_DIM,),
+             "initial_seasonal": (ETS_SEASONAL_DIM,),
+             "seasonal_param": (ETS_SEASONAL_DIM,),
+         }
+
+         if self.dense_innovation_covariance:
+             del coord_map["sigma_state"]
+             coord_map["state_cov"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
+
+         if self.k_endog == 1:
+             coord_map["sigma_state"] = None
+             coord_map["sigma_obs"] = None
+             coord_map["initial_level"] = None
+             coord_map["initial_trend"] = None
+         else:
+             coord_map["alpha"] = (OBS_STATE_DIM,)
+             coord_map["beta"] = (OBS_STATE_DIM,)
+             coord_map["gamma"] = (OBS_STATE_DIM,)
+             coord_map["phi"] = (OBS_STATE_DIM,)
+             coord_map["initial_seasonal"] = (OBS_STATE_DIM, ETS_SEASONAL_DIM)
+             coord_map["seasonal_param"] = (OBS_STATE_DIM, ETS_SEASONAL_DIM)
+
+         if not self.measurement_error:
+             del coord_map["sigma_obs"]
+         if not self.seasonal:
+             del coord_map["seasonal_param"]
+
+         return coord_map
+
+     @property
+     def coords(self) -> dict[str, Sequence]:
+         coords = make_default_coords(self)
+         if self.seasonal:
+             coords.update({ETS_SEASONAL_DIM: list(range(1, self.seasonal_periods + 1))})
+
+         return coords
+
+     def _stationary_initialization(self, T_stationary):
+         # Solve a matrix quadratic equation for P0
+         R = self.ssm["selection"]
+         Q = self.ssm["state_cov"]
+
+         # ETS models are not stationary, but we can proceed *as if* the model were stationary by introducing large
+         # dampening factors on all components. We then set the initial covariance to the steady-state of that
+         # system, which we hope is similar enough to give a good initialization for the non-stationary system.
+
+         T_stationary = pt.specify_shape(T_stationary, (self.k_states, self.k_states))
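+         # The steady-state covariance solves the discrete Lyapunov equation
+         # P0 = T_stationary @ P0 @ T_stationary.T + R @ Q @ R.T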
+         P0 = solve_discrete_lyapunov(T_stationary, pt.linalg.matrix_dot(R, Q, R.T))
+         P0 = pt.specify_shape(P0, (self.k_states, self.k_states))
+
+         return P0
+
+     def make_symbolic_graph(self) -> None:
+         k_states_each = self.k_states // self.k_endog
+
+         initial_level = self.make_and_register_variable(
+             "initial_level", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
+         )
+
+         initial_states = [pt.zeros(k_states_each) for _ in range(self.k_endog)]
+         if self.k_endog == 1:
+             initial_states = [pt.set_subtensor(initial_states[0][1], initial_level)]
+         else:
+             initial_states = [
+                 pt.set_subtensor(initial_state[1], initial_level[i])
+                 for i, initial_state in enumerate(initial_states)
+             ]
+
+         # The shape of R can be pre-allocated, then filled with the required parameters
+         R = pt.zeros((self.k_states // self.k_endog, 1))
+
+         alpha = self.make_and_register_variable(
+             "alpha", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
+         )
+
+         # This is a dummy value for initialization. When we do a stationary initialization, it will be set to a
+         # value close to 1. Otherwise, it will be 1. We do not want this value to exist outside of this method.
+         stationary_dampening = pt.scalar("dampen_dummy")
+
+         if self.k_endog == 1:
+             # The R[0, 0] entry needs to be adjusted for a shift in the time indices. Consider the (A, N, N) model:
+             # y_t = l_{t-1} + e_t
+             # l_t = l_{t-1} + alpha * e_t
+             R_list = [pt.set_subtensor(R[1, 0], alpha)]  # and l_t = ... + alpha * e_t
+
+             # We want the first equation to be in terms of time t on the RHS, because our observation equation is
+             # always y_t = Z @ x_t. Re-arranging equation 2, we get l_{t-1} = l_t - alpha * e_t
+             # --> y_t = l_t + e_t - alpha * e_t --> y_t = l_t + (1 - alpha) * e_t
+             R_list = [pt.set_subtensor(R[0, :], (1 - alpha)) for R in R_list]
+         else:
+             # If there are multiple endog, clone the basic R matrix and modify the appropriate entries
+             R_list = [pt.set_subtensor(R[1, 0], alpha[i]) for i in range(self.k_endog)]
+             R_list = [pt.set_subtensor(R[0, :], (1 - alpha[i])) for i, R in enumerate(R_list)]
+
+         # Shock and level components always exist; the base case is e_t = e_t and l_t = l_{t-1}
+         T_base = pt.set_subtensor(pt.zeros((2, 2))[1, 1], stationary_dampening)
+
+         if self.trend:
+             initial_trend = self.make_and_register_variable(
+                 "initial_trend", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
+             )
+
+             if self.k_endog == 1:
+                 initial_states = [pt.set_subtensor(initial_states[0][2], initial_trend)]
+             else:
+                 initial_states = [
+                     pt.set_subtensor(initial_state[2], initial_trend[i])
+                     for i, initial_state in enumerate(initial_states)
+                 ]
+             beta = self.make_and_register_variable(
+                 "beta", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
+             )
+             if self.use_transformed_parameterization:
+                 param = beta
+             else:
+                 param = alpha * beta
+             if self.k_endog == 1:
+                 R_list = [pt.set_subtensor(R[2, 0], param) for R in R_list]
+             else:
+                 R_list = [pt.set_subtensor(R[2, 0], param[i]) for i, R in enumerate(R_list)]
+
+             # If a trend is requested, we have the following transition equations (omitting the shocks):
+             # l_t = l_{t-1} + b_{t-1}
+             # b_t = b_{t-1}
+             T_base = pt.as_tensor_variable(([0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]))
+             T_base = pt.set_subtensor(T_base[[1, 2], [1, 2]], stationary_dampening)
+
+         if self.damped_trend:
+             phi = self.make_and_register_variable(
+                 "phi", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
+             )
+             # We are always in the case where we have a trend, so we can add the dampening parameter to T_base
+             # defined in that branch. Transition equations become:
+             # l_t = l_{t-1} + phi * b_{t-1}
+             # b_t = phi * b_{t-1}
+             if self.k_endog > 1:
+                 T_base = [pt.set_subtensor(T_base[1:, 2], phi[i]) for i in range(self.k_endog)]
+             else:
+                 T_base = pt.set_subtensor(T_base[1:, 2], phi)
+
+         T_components = (
+             [T_base for _ in range(self.k_endog)] if not isinstance(T_base, list) else T_base
+         )
+
+         if self.seasonal:
+             initial_seasonal = self.make_and_register_variable(
+                 "initial_seasonal",
+                 shape=(self.seasonal_periods,)
+                 if self.k_endog == 1
+                 else (self.k_endog, self.seasonal_periods),
+                 dtype=floatX,
+             )
+             if self.k_endog == 1:
+                 initial_states = [
+                     pt.set_subtensor(initial_states[0][2 + int(self.trend) :], initial_seasonal)
+                 ]
+             else:
+                 initial_states = [
+                     pt.set_subtensor(initial_state[2 + int(self.trend) :], initial_seasonal[i])
+                     for i, initial_state in enumerate(initial_states)
+                 ]
+
+             gamma = self.make_and_register_variable(
+                 "gamma", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
+             )
+
+             param = gamma if self.use_transformed_parameterization else (1 - alpha) * gamma
+             # Additional adjustment to the R[0, 0] position is required. Start from:
+             # y_t = l_{t-1} + s_{t-m} + e_t
+             # l_t = l_{t-1} + alpha * e_t
+             # s_t = s_{t-m} + gamma * e_t
+             # Solve for l_{t-1} and s_{t-m} in terms of l_t and s_t, then substitute into the observation equation:
+             # y_t = l_t + s_t - alpha * e_t - gamma * e_t + e_t --> y_t = l_t + s_t + (1 - alpha - gamma) * e_t
+
+             if self.k_endog == 1:
+                 R_list = [pt.set_subtensor(R[2 + int(self.trend), 0], param) for R in R_list]
+                 R_list = [pt.set_subtensor(R[0, 0], R[0, 0] - param) for R in R_list]
+
+             else:
+                 R_list = [
+                     pt.set_subtensor(R[2 + int(self.trend), 0], param[i])
+                     for i, R in enumerate(R_list)
+                 ]
+                 R_list = [
+                     pt.set_subtensor(R[0, 0], R[0, 0] - param[i]) for i, R in enumerate(R_list)
+                 ]
+
+             # The seasonal component is always going to look like a TimeFrequency structural component, see that
+             # docstring for more details
+             T_seasonals = [pt.eye(self.seasonal_periods, k=-1) for _ in range(self.k_endog)]
+             T_seasonals = [
+                 pt.set_subtensor(T_seasonal[0, -1], stationary_dampening)
+                 for T_seasonal in T_seasonals
+             ]
+
+             # Organize the components so it goes T1, T_seasonal_1, T2, T_seasonal_2, etc.
+             T_components = [
+                 matrix[i] for i in range(self.k_endog) for matrix in [T_components, T_seasonals]
+             ]
+
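+         # Assemble the joint system: per-series initial states are stacked, and the per-series
+         # selection and transition matrices are combined block-diagonally.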
+         x0 = pt.concatenate(initial_states, axis=0)
+         R = pt.linalg.block_diag(*R_list)
+
+         self.ssm["initial_state"] = x0
+         self.ssm["selection"] = pt.specify_shape(R, shape=(self.k_states, self.k_posdef))
+
+         T = pt.linalg.block_diag(*T_components)
+
+         # Remove the stationary_dampening dummies before saving the transition matrix
+         self.ssm["transition"] = pt.specify_shape(
+             graph_replace(T, {stationary_dampening: 1.0}), (self.k_states, self.k_states)
+         )
+
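+         # Each row of the design matrix observes the sum of that series' innovation state, level state,
+         # and (when present) its current seasonal state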
+         Zs = [np.zeros((self.k_endog, self.k_states // self.k_endog)) for _ in range(self.k_endog)]
+         for i, Z in enumerate(Zs):
+             Z[i, 0] = 1.0  # innovation
+             Z[i, 1] = 1.0  # level
+             if self.seasonal:
+                 Z[i, 2 + int(self.trend)] = 1.0
+
+         Z = pt.concatenate(Zs, axis=1)
+
+         self.ssm["design"] = Z
+
+         # Set up the state covariance matrix
+         if self.dense_innovation_covariance:
+             state_cov = self.make_and_register_variable(
+                 "state_cov", shape=(self.k_posdef, self.k_posdef), dtype=floatX
+             )
+             self.ssm["state_cov"] = state_cov
+
+         else:
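+             # Diagonal innovation covariance: sigma_state holds per-series standard deviations,
+             # which are squared and written onto the diagonal of state_cov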
+             state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
+             state_cov = self.make_and_register_variable(
+                 "sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
+             )
+             self.ssm[state_cov_idx] = state_cov**2
+
+         if self.measurement_error:
+             obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))
+             obs_cov = self.make_and_register_variable(
+                 "sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
+             )
+             self.ssm[obs_cov_idx] = obs_cov**2
+
+         if self.stationary_initialization:
+             T_stationary = graph_replace(T, {stationary_dampening: self.initialization_dampening})
+             P0 = self._stationary_initialization(T_stationary)
+
+         else:
+             P0 = self.make_and_register_variable(
+                 "P0", shape=(self.k_states, self.k_states), dtype=floatX
+             )
+
+         self.ssm["initial_state_cov"] = P0