pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pytensor.tensor as pt
|
|
6
|
+
|
|
7
|
+
from pytensor.tensor.slinalg import solve_discrete_lyapunov
|
|
8
|
+
|
|
9
|
+
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
|
|
10
|
+
from pymc_extras.statespace.models.utilities import (
|
|
11
|
+
make_default_coords,
|
|
12
|
+
make_harvey_state_names,
|
|
13
|
+
make_SARIMA_transition_matrix,
|
|
14
|
+
)
|
|
15
|
+
from pymc_extras.statespace.utils.constants import (
|
|
16
|
+
ALL_STATE_AUX_DIM,
|
|
17
|
+
ALL_STATE_DIM,
|
|
18
|
+
AR_PARAM_DIM,
|
|
19
|
+
MA_PARAM_DIM,
|
|
20
|
+
OBS_STATE_DIM,
|
|
21
|
+
SARIMAX_STATE_STRUCTURES,
|
|
22
|
+
SEASONAL_AR_PARAM_DIM,
|
|
23
|
+
SEASONAL_MA_PARAM_DIM,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _verify_order(p, d, q, P, D, Q, S):
|
|
28
|
+
for name, terms in zip(["AR", "MA"], [(p, P), (q, Q)]):
|
|
29
|
+
a, A = terms
|
|
30
|
+
seasonal_lags = [(1 + i) * S for i in range(A)]
|
|
31
|
+
lags = [(1 + i) for i in range(a)]
|
|
32
|
+
overlapping_terms = set(seasonal_lags).intersection(set(lags))
|
|
33
|
+
if any(overlapping_terms):
|
|
34
|
+
raise ValueError(
|
|
35
|
+
f"The following {name} and seasonal {name} terms overlap, check model "
|
|
36
|
+
f"definition: {overlapping_terms}"
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BayesianSARIMA(PyMCStateSpace):
    r"""
    Seasonal AutoRegressive Integrated Moving Average with eXogenous regressors

    Parameters
    ----------
    order: tuple(int, int, int)
        Order of the ARIMA process. The order has the notation (p, d, q), where p is the number of autoregressive
        lags, q is the number of moving average components, and d is order of integration -- the number of
        differences needed to render the data stationary.

        If d > 0, the differences are modeled as components of the hidden state, and all available data can be used.
        This is only possible if state_structure = 'fast'. For interpretable states, the user must manually
        difference the data prior to calling the `build_statespace_graph` method.

    seasonal_order: tuple(int, int, int, int), optional
        Seasonal order of the SARIMA process. The order has the notation (P, D, Q, S), where P is the number of seasonal
        lags to include, Q is the number of seasonal innovation lags to include, and D is the number of seasonal
        differences to perform. S is the length of the season, and should be a positive integer whenever P, D or Q
        is non-zero.

        Seasonal terms are similar to ARIMA terms, in that they are merely lags of the data or innovations. It is thus
        possible for the seasonal lags and the ARIMA lags to overlap, for example if P > 0 and S <= p. In this case,
        an error will be raised.

    stationary_initialization: bool, default True
        If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
        state values will be used.

        .. warning:: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics
                  for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite"
                  error.

    filter_type: str, default "standard"
        The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
        and "cholesky". See the docs for kalman filters for more details.

    state_structure: str, default "fast"
        How to represent the state-space system. Currently, there are two choices: "fast" or "interpretable"

        - "fast" corresponds to the state space used by [2], and is called the "Harvey" representation in statsmodels.
          This is also the default representation used by statsmodels.tsa.statespace.SARIMAX. The states combine lags
          and innovations at different lags to compress the dimension of the state vector to
          max(p + P * S, 1 + q + Q * S) + d + S * D. As a result, it is very performant, but only the first state has
          a clear interpretation.

        - "interpretable" maximally expands the state vector, doing zero state compression. As a result, the state has
          dimension max(1, p + P * S) + max(1, q + Q * S). What is gained by doing this is that every state has an
          obvious meaning, as either the data, an innovation, or a lag thereof.

    measurement_error: bool, default False
        If true, a measurement error term is added to the model.

    verbose: bool, default True
        If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.

    Notes
    -----
    The ARIMAX model is a univariate time series model that posits the future evolution of a stationary time series will
    be a function of its past values, together with exogenous "innovations" and their past history. The model is
    described by its "order", a 3-tuple (p, d, q), that are:

        - p: The number of past time steps that directly influence the present value of the time series, called the
            "autoregressive", or AR, component
        - d: The "integration" order of the time series
        - q: The number of past exogenous innovations that directly influence the present value of the time series,
            called the "moving average", or MA, component

    Given this 3-tuple, the model can be written:

    .. math::
        (1- \phi_1 B - \cdots - \phi_p B^p) (1-B)^d y_{t} = c + (1 + \theta_1 B + \cdots + \theta_q B^q) \varepsilon_t

    Where B is the backshift operator, :math:`By_{t} = y_{t-1}`.

    The model assumes that the data are stationary; that is, that they can be described by a time-invariant Gaussian
    distribution with fixed mean and finite variance. Non-stationary data, those that grow over time, are not suitable
    for ARIMA modeling without preprocessing. Stationarity can be induced in any time series by the sequential
    application of differences. Given a hypothetical non-stationary process:

    .. math::
        y_{t} = c + \rho y_{t-1} + \varepsilon_{t}

    The process:

    .. math::
        \Delta y_{t} = y_{t} - y_{t-1} = \rho \Delta y_{t-1} + \Delta \varepsilon_t

    is stationary, as the non-stationary component :math:`c` was eliminated by the operation of differencing. This
    process is said to be "integrated of order 1", as it requires 1 difference to render stationary. This is the
    function of the `d` parameter in the ARIMA order.

    Alternatively, the non-stationary components can be directly estimated. In this case, the errors of a preliminary
    regression are assumed to be ARIMA distributed, so that:

    .. math::
        \begin{align}
        y_{t} &= X\beta + \eta_t \\
        (1- \phi_1 B - \cdots - \phi_p B^p) (1-B)^d \eta_{t} &= (1 + \theta_1 B + \cdots + \theta_q B^q) \varepsilon_t
        \end{align}

    Where the design matrix `X` can include a constant, trends, or exogenous regressors.

    ARIMA models can be represented in statespace form, as described in [1]. For more details, see chapters 3.4, 3.6,
    and 8.4.

    Examples
    --------
    The following example shows how to build an ARMA(1, 1) model -- ARIMA(1, 0, 1) -- using the BayesianSARIMA class:

    .. code:: python

        import pymc_extras.statespace as pmss
        import pymc as pm

        ss_mod = pmss.BayesianSARIMA(order=(1, 0, 1), verbose=True)

        with pm.Model(coords=ss_mod.coords) as arma_model:
            state_sigmas = pm.HalfNormal("sigma_state", sigma=1.0, dims=ss_mod.param_dims["sigma_state"])

            rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"])
            theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"])

            ss_mod.build_statespace_graph(df, mode="JAX")
            idata = pm.sample(nuts_sampler='numpyro')

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.

    .. [2] Harvey, A. C. (1989). Forecasting, Structural Time Series Models and the
        Kalman Filter. Cambridge: Cambridge University Press.
    """

    def __init__(
        self,
        order: tuple[int, int, int],
        seasonal_order: tuple[int, int, int, int] | None = None,
        stationary_initialization: bool = True,
        filter_type: str = "standard",
        state_structure: str = "fast",
        measurement_error: bool = False,
        verbose: bool = True,
    ):
        # Model order
        self.p, self.d, self.q = order
        if seasonal_order is None:
            seasonal_order = (0, 0, 0, 0)

        self.P, self.D, self.Q, self.S = seasonal_order
        _verify_order(self.p, self.d, self.q, self.P, self.D, self.Q, self.S)

        self.stationary_initialization = stationary_initialization

        self.state_structure = state_structure

        # Number of data states (current value + lags) and innovation states (current + lags) used by the
        # "interpretable" representation; each is at least 1.
        self._p_max = max(1, self.p + self.P * self.S)
        self._q_max = max(1, self.q + self.Q * self.S)

        k_states = None
        # Number of leading states that hold the (seasonal) differences in the "fast" representation
        self._k_diffs = self.d + self.S * self.D

        if state_structure not in SARIMAX_STATE_STRUCTURES:
            raise ValueError(
                f"Got invalid argument {state_structure} for state structure, expected one of "
                f'{", ".join(SARIMAX_STATE_STRUCTURES)}'
            )

        if state_structure == "interpretable" and (self.d + self.D) > 0:
            raise ValueError(
                "Cannot use interpretable state structure with statespace differencing. Difference the "
                'data by hand (leaving NaN values to be interpolated), or use state_structure="fast"'
            )

        if self.state_structure == "fast":
            k_states = max(self.p + self.P * self.S, self.q + self.Q * self.S + 1) + (
                self.S * self.D + self.d
            )
        elif self.state_structure == "interpretable":
            k_states = self._p_max + self._q_max

        k_posdef = 1
        k_endog = 1

        super().__init__(
            k_endog,
            k_states,
            k_posdef,
            filter_type,
            verbose=verbose,
            measurement_error=measurement_error,
        )

    @property
    def param_names(self):
        """Names of the parameters that must be given priors, given the model order and options."""
        names = [
            "x0",
            "P0",
            "ar_params",
            "ma_params",
            "seasonal_ar_params",
            "seasonal_ma_params",
            "sigma_state",
            "sigma_obs",
        ]
        # Initial state and covariance are derived from the dynamics under stationary initialization
        if self.stationary_initialization:
            names.remove("P0")
            names.remove("x0")
        if self.p == 0:
            names.remove("ar_params")
        if self.P == 0:
            names.remove("seasonal_ar_params")
        if self.q == 0:
            names.remove("ma_params")
        if self.Q == 0:
            names.remove("seasonal_ma_params")
        if not self.measurement_error:
            names.remove("sigma_obs")

        return names

    @property
    def param_info(self) -> dict[str, dict[str, Any]]:
        """Shape, constraint and dims information for every active parameter, keyed by name."""
        info = {
            "x0": {
                "shape": (self.k_states,),
                "constraints": None,
            },
            "P0": {
                "shape": (self.k_states, self.k_states),
                "constraints": "Positive Semi-definite",
            },
            "sigma_obs": {
                "shape": None if self.k_endog == 1 else (self.k_endog,),
                "constraints": "Positive",
            },
            "sigma_state": {
                "shape": None if self.k_posdef == 1 else (self.k_posdef,),
                "constraints": "Positive",
            },
            "ar_params": {
                "shape": (self.p,),
                "constraints": "None",
            },
            "ma_params": {
                "shape": (self.q,),
                "constraints": "None",
            },
            "seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
            "seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
        }

        for name in self.param_names:
            info[name]["dims"] = self.param_dims[name]

        return {name: info[name] for name in self.param_names}

    @property
    def state_names(self):
        """One name per hidden state, in state-vector order."""
        if self.state_structure == "fast":
            p, d, q = self.p, self.d, self.q
            P, D, Q, S = self.P, self.D, self.Q, self.S
            states = make_harvey_state_names(p, d, q, P, D, Q, S)

        elif self.state_structure == "interpretable":
            # The state is [data, lagged data..., innovations, lagged innovations...] with _p_max data states and
            # _q_max innovation states (k_states == _p_max + _q_max). The lag names are generated from _p_max/_q_max
            # rather than gated on p/q, so seasonal-only models (p == 0, P > 0 or q == 0, Q > 0) still get exactly
            # one name per state.
            states = ["data"]
            states += [f"L{i + 1}.data" for i in range(self._p_max - 1)]
            states += ["innovations"]
            states += [f"L{i + 1}.innovations" for i in range(self._q_max - 1)]
        else:
            raise NotImplementedError()

        return states

    @property
    def observed_states(self):
        """The single observed state: the first entry of the state vector."""
        return [self.state_names[0]]

    @property
    def shock_names(self):
        """The model has a single innovation shock."""
        return ["innovation"]

    @property
    def param_dims(self):
        """Dimension names for each active parameter (None for scalars)."""
        coord_map = {
            "x0": (ALL_STATE_DIM,),
            "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
            "sigma_obs": (OBS_STATE_DIM,),
            "sigma_state": (OBS_STATE_DIM,),
            "ar_params": (AR_PARAM_DIM,),
            "ma_params": (MA_PARAM_DIM,),
            "seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
            "seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
        }
        # Scalar standard deviations carry no dims; consistent with the shapes reported by param_info
        if self.k_posdef == 1:
            coord_map["sigma_state"] = None
        if self.k_endog == 1:
            coord_map["sigma_obs"] = None
        if not self.measurement_error:
            del coord_map["sigma_obs"]
        if self.p == 0:
            del coord_map["ar_params"]
        if self.q == 0:
            del coord_map["ma_params"]
        if self.P == 0:
            del coord_map["seasonal_ar_params"]
        if self.Q == 0:
            del coord_map["seasonal_ma_params"]
        if self.stationary_initialization:
            del coord_map["P0"]
            del coord_map["x0"]

        return coord_map

    @property
    def coords(self) -> dict[str, Sequence]:
        """Default coords plus 1-based lag labels for each active (seasonal) AR/MA parameter vector."""
        coords = make_default_coords(self)
        if self.p > 0:
            coords.update({AR_PARAM_DIM: list(range(1, self.p + 1))})
        if self.q > 0:
            coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
        if self.P > 0:
            coords.update({SEASONAL_AR_PARAM_DIM: list(range(1, self.P + 1))})
        if self.Q > 0:
            coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})

        return coords

    def _stationary_initialization(self, mode=None):
        """
        Compute the steady-state initial mean and covariance from the current system matrices.

        x0 solves (I - T) x0 = c, and P0 solves the discrete Lyapunov equation P0 = T P0 T' + R Q R'.
        """
        # Solve for matrix quadratic for P0
        T = self.ssm["transition"]
        R = self.ssm["selection"]
        Q = self.ssm["state_cov"]
        c = self.ssm["state_intercept"]

        x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)

        # The direct solver is used for small systems or when compiling to JAX; bilinear otherwise
        method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear"
        P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method)

        return x0, P0

    def make_symbolic_graph(self) -> None:
        """Register the model parameters and write them into the state-space matrices."""
        p, d, q = self.p, self.d, self.q
        P, D, Q, S = self.P, self.D, self.Q, self.S

        # Initial state and covariance can be handled first if we're not doing a stationary initialization
        if not self.stationary_initialization:
            x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)
            P0 = self.make_and_register_variable(
                "P0", shape=(self.k_states, self.k_states), dtype=floatX
            )

            self.ssm["initial_state", :] = x0
            self.ssm["initial_state_cov"] = P0

        # Design matrix has no RVs
        k_lags = self.k_states - self._k_diffs
        self.ssm["design"] = np.r_[[1] * d, ([0] * (S - 1) + [1]) * D, [1], [0] * (k_lags - 1)][
            None
        ]

        # Set up the transition and selection matrices, depending on the requested representation
        if self.state_structure == "fast":
            transition = make_SARIMA_transition_matrix(p, d, q, P, D, Q, S)
            # The innovation enters the first state after the difference states
            selection = np.r_[
                [0] * self._k_diffs, [1.0], np.zeros(self.k_states - self._k_diffs - 1)
            ][:, None]

            # AR coefficients fill the first column of the ARMA block; MA coefficients fill the selection
            # rows that follow the first ARMA state.
            ar_param_idx = np.s_[
                "transition", self._k_diffs : self._k_diffs + self.p, self._k_diffs
            ]
            ma_param_idx = np.s_["selection", 1 + self._k_diffs : 1 + self._k_diffs + self.q, 0]

            self.ssm["transition"] = transition
            self.ssm["selection"] = selection

            if p > 0:
                ar_params = self.make_and_register_variable("ar_params", shape=(p,), dtype=floatX)
                self.ssm[ar_param_idx] = ar_params

            if P > 0:
                seasonal_ar_params = self.make_and_register_variable(
                    "seasonal_ar_params", shape=(P,), dtype=floatX
                )
                # Seasonal AR coefficient i acts on lag i * S
                idx_rows = self._k_diffs + (np.arange(1, P + 1) * S) - 1
                S_ar_param_idx = np.s_["transition", idx_rows, self._k_diffs]
                self.ssm[S_ar_param_idx] = seasonal_ar_params

                if p > 0:
                    # Expanding (1 - sum phi_j B^j)(1 - sum Phi_i B^{iS}) yields cross terms -Phi_i * phi_j
                    # on lag i * S + j.
                    cross_term_idx = np.s_[
                        "transition",
                        idx_rows.repeat(p) + np.tile(np.arange(p), P) + 1,
                        self._k_diffs,
                    ]
                    self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile(
                        ar_params, P
                    )

            if q > 0:
                ma_params = self.make_and_register_variable("ma_params", shape=(q,), dtype=floatX)
                self.ssm[ma_param_idx] = ma_params

            if Q > 0:
                seasonal_ma_params = self.make_and_register_variable(
                    "seasonal_ma_params", shape=(Q,), dtype=floatX
                )
                # Seasonal MA coefficient i acts on innovation lag i * S
                idx_rows = self._k_diffs + np.arange(1, Q + 1) * S
                S_ma_param_idx = np.s_["selection", idx_rows, 0]
                self.ssm[S_ma_param_idx] = seasonal_ma_params

                if q > 0:
                    # Expanding (1 + sum theta_j B^j)(1 + sum Theta_i B^{iS}) yields cross terms
                    # +Theta_i * theta_j on lag i * S + j.
                    cross_term_idx = np.s_[
                        "selection", idx_rows.repeat(q) + np.tile(np.arange(q), Q) + 1, 0
                    ]
                    self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile(
                        ma_params, Q
                    )

        elif self.state_structure == "interpretable":
            # All dynamics live in the first row of the transition matrix: AR coefficients multiply the lagged
            # data states, MA coefficients multiply the lagged innovation states.
            ar_param_idx = np.s_["transition", 0, : max(1, p)]
            ma_param_idx = np.s_["transition", 0, self._p_max : self._p_max + max(1, q)]

            # Sub-diagonal shifts each lag down by one step...
            transition = np.eye(self.k_states, k=-1)
            # ...except between the last data lag and the first innovation state, which are unrelated
            transition[-self._q_max, self._p_max - 1] = 0

            # The innovation enters both the data state and the current-innovation state
            selection = np.r_[[1.0], np.zeros(self.k_states - 1)][:, None]
            selection[-self._q_max, 0] = 1

            self.ssm["transition"] = transition
            self.ssm["selection"] = selection

            if self.p > 0:
                ar_params = self.make_and_register_variable(
                    "ar_params", shape=(self.p,), dtype=floatX
                )
                self.ssm[ar_param_idx] = ar_params

            if self.P > 0:
                seasonal_ar_params = self.make_and_register_variable(
                    "seasonal_ar_params", shape=(P,), dtype=floatX
                )
                # Column S * i - 1 holds the data lagged by S * i periods
                idx_cols = np.arange(1, P + 1) * S - 1
                S_ar_param_idx = np.s_["transition", 0, idx_cols]
                self.ssm[S_ar_param_idx] = seasonal_ar_params

                if p > 0:
                    cross_term_idx = np.s_[
                        "transition", 0, idx_cols.repeat(p) + np.tile(np.arange(p), P) + 1
                    ]
                    self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile(
                        ar_params, P
                    )

            if self.q > 0:
                ma_params = self.make_and_register_variable(
                    "ma_params", shape=(self.q,), dtype=floatX
                )
                self.ssm[ma_param_idx] = ma_params

            if Q > 0:
                seasonal_ma_params = self.make_and_register_variable(
                    "seasonal_ma_params", shape=(Q,), dtype=floatX
                )
                idx_cols = self._p_max + np.arange(1, Q + 1) * S - 1
                S_ma_param_idx = np.s_["transition", 0, idx_cols]
                self.ssm[S_ma_param_idx] = seasonal_ma_params

                if q > 0:
                    cross_term_idx = np.s_[
                        "transition", 0, idx_cols.repeat(q) + np.tile(np.arange(q), Q) + 1
                    ]
                    self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile(
                        ma_params, Q
                    )

        # Set up the state covariance matrix
        state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
        state_cov = self.make_and_register_variable(
            "sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
        )
        self.ssm[state_cov_idx] = state_cov**2

        if self.measurement_error:
            obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))
            obs_cov = self.make_and_register_variable(
                "sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
            )
            self.ssm[obs_cov_idx] = obs_cov**2

        # The initial conditions have to be done last in the case of stationary initialization, because it will depend
        # on c, T, R and Q
        if self.stationary_initialization:
            x0, P0 = self._stationary_initialization()
            self.ssm["initial_state", :] = x0
            self.ssm["initial_state_cov", :, :] = P0
|