pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,536 @@
1
+ from collections.abc import Sequence
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import pytensor.tensor as pt
6
+
7
+ from pytensor.tensor.slinalg import solve_discrete_lyapunov
8
+
9
+ from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
10
+ from pymc_extras.statespace.models.utilities import (
11
+ make_default_coords,
12
+ make_harvey_state_names,
13
+ make_SARIMA_transition_matrix,
14
+ )
15
+ from pymc_extras.statespace.utils.constants import (
16
+ ALL_STATE_AUX_DIM,
17
+ ALL_STATE_DIM,
18
+ AR_PARAM_DIM,
19
+ MA_PARAM_DIM,
20
+ OBS_STATE_DIM,
21
+ SARIMAX_STATE_STRUCTURES,
22
+ SEASONAL_AR_PARAM_DIM,
23
+ SEASONAL_MA_PARAM_DIM,
24
+ )
25
+
26
+
27
+ def _verify_order(p, d, q, P, D, Q, S):
28
+ for name, terms in zip(["AR", "MA"], [(p, P), (q, Q)]):
29
+ a, A = terms
30
+ seasonal_lags = [(1 + i) * S for i in range(A)]
31
+ lags = [(1 + i) for i in range(a)]
32
+ overlapping_terms = set(seasonal_lags).intersection(set(lags))
33
+ if any(overlapping_terms):
34
+ raise ValueError(
35
+ f"The following {name} and seasonal {name} terms overlap, check model "
36
+ f"definition: {overlapping_terms}"
37
+ )
38
+
39
+
40
class BayesianSARIMA(PyMCStateSpace):
    r"""
    Seasonal AutoRegressive Integrated Moving Average with eXogenous regressors

    Parameters
    ----------
    order: tuple(int, int, int)
        Order of the ARIMA process. The order has the notation (p, d, q), where p is the number of autoregressive
        lags, q is the number of moving average components, and d is order of integration -- the number of
        differences needed to render the data stationary.

        If d > 0, the differences are modeled as components of the hidden state, and all available data can be used.
        This is only possible if state_structure = 'fast'. For interpretable states, the user must manually
        difference the data prior to calling the `build_statespace_graph` method.

    seasonal_order: tuple(int, int, int, int), optional
        Seasonal order of the SARIMA process. The order has the notation (P, D, Q, S), where P is the number of seasonal
        lags to include, Q is the number of seasonal innovation lags to include, and D is the number of seasonal
        differences to perform. S is the length of the season, and must be at least 1 whenever P, D or Q is nonzero.

        Seasonal terms are similar to ARIMA terms, in that they are merely lags of the data or innovations. It is thus
        possible for the seasonal lags and the ARIMA lags to overlap, for example if S <= p and P > 0 (or S <= q and
        Q > 0). In this case, an error will be raised.

    stationary_initialization: bool, default True
        If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
        state values will be used.

        .. warning:: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics
          for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite"
          error.

        .. note:: The steady state mean is obtained by solving a linear system in ``I - T``. When the transition
          matrix ``T`` has unit roots -- as is typically the case when statespace differencing is used (d > 0 or
          D > 0) -- this system is singular, so a stationary initialization is not meaningful in that setting.

    filter_type: str, default "standard"
        The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
        and "cholesky". See the docs for kalman filters for more details.

    state_structure: str, default "fast"
        How to represent the state-space system. Currently, there are two choices: "fast" or "interpretable"

        - "fast" corresponds to the state space used by [2], and is called the "Harvey" representation in statsmodels.
          This is also the default representation used by statsmodels.tsa.statespace.SARIMAX. The states combine lags
          and innovations at different lags to compress the dimension of the state vector to
          max(p + P * S, 1 + q + Q * S), plus one additional state per difference (d + S * D). As a result, it is
          very performant, but only the first state after the difference states has a clear interpretation.

        - "interpretable" maximally expands the state vector, doing zero state compression. As a result, the state has
          dimension max(1, p + P * S) + max(1, q + Q * S). What is gained by doing this is that every state has an
          obvious meaning, as either the data, an innovation, or a lag thereof.

    measurement_error: bool, default False
        If true, a measurement error term is added to the model.

    verbose: bool, default True
        If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.

    Notes
    -----
    The ARIMAX model is a univariate time series model that posits the future evolution of a stationary time series will
    be a function of its past values, together with exogenous "innovations" and their past history. The model is
    described by its "order", a 3-tuple (p, d, q), that are:

        - p: The number of past time steps that directly influence the present value of the time series, called the
          "autoregressive", or AR, component
        - d: The "integration" order of the time series
        - q: The number of past exogenous innovations that directly influence the present value of the time series,
          called the "moving average", or MA, component

    Given this 3-tuple, the model can be written:

    .. math::
        (1- \phi_1 B - \cdots - \phi_p B^p) (1-B)^d y_{t} = c + (1 + \theta_1 B + \cdots + \theta_q B^q) \varepsilon_t

    Where B is the backshift operator, :math:`By_{t} = y_{t-1}`.

    The model assumes that the data are stationary; that is, that they can be described by a time-invariant Gaussian
    distribution with fixed mean and finite variance. Non-stationary data, those that grow over time, are not suitable
    for ARIMA modeling without preprocessing. Stationarity can be induced in any time series by the sequential
    application of differences. Given a hypothetical non-stationary process:

    .. math::
        y_{t} = c + \rho y_{t-1} + \varepsilon_{t}

    The process:

    .. math::
        \Delta y_{t} = y_{t} - y_{t-1} = \rho \Delta y_{t-1} + \Delta \varepsilon_t

    is stationary, as the non-stationary component :math:`c` was eliminated by the operation of differencing. This
    process is said to be "integrated of order 1", as it requires 1 difference to render stationary. This is the
    function of the `d` parameter in the ARIMA order.

    Alternatively, the non-stationary components can be directly estimated. In this case, the errors of a preliminary
    regression are assumed to be ARIMA distributed, so that:

    .. math::
        \begin{align}
            y_{t} &= X\beta + \eta_t \\
            (1- \phi_1 B - \cdots - \phi_p B^p) (1-B)^d \eta_{t} &= (1 + \theta_1 B + \cdots + \theta_q B^q) \varepsilon_t
        \end{align}

    Where the design matrix `X` can include a constant, trends, or exogenous regressors.

    ARIMA models can be represented in statespace form, as described in [1]. For more details, see chapters 3.4, 3.6,
    and 8.4.

    Examples
    --------
    The following example shows how to build an ARMA(1, 1) model -- ARIMA(1, 0, 1) -- using the BayesianSARIMA class:

    .. code:: python

        import pymc_extras.statespace as pmss
        import pymc as pm

        ss_mod = pmss.BayesianSARIMA(order=(1, 0, 1), verbose=True)

        with pm.Model(coords=ss_mod.coords) as arma_model:
            state_sigmas = pm.HalfNormal("sigma_state", sigma=1.0, dims=ss_mod.param_dims["sigma_state"])

            rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"])
            theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"])

            ss_mod.build_statespace_graph(df, mode="JAX")
            idata = pm.sample(nuts_sampler='numpyro')

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.

    .. [2] Harvey, A. C. (1989). Forecasting, Structural Time Series Models and the
        Kalman Filter. Cambridge: Cambridge University Press.
    """

    def __init__(
        self,
        order: tuple[int, int, int],
        seasonal_order: tuple[int, int, int, int] | None = None,
        stationary_initialization: bool = True,
        filter_type: str = "standard",
        state_structure: str = "fast",
        measurement_error: bool = False,
        verbose: bool = True,
    ):
        # Model order
        self.p, self.d, self.q = order
        if seasonal_order is None:
            seasonal_order = (0, 0, 0, 0)

        self.P, self.D, self.Q, self.S = seasonal_order
        _verify_order(self.p, self.d, self.q, self.P, self.D, self.Q, self.S)

        self.stationary_initialization = stationary_initialization

        self.state_structure = state_structure

        # Longest data / innovation lag once seasonal lags are folded in. Floored at 1 so that the interpretable
        # representation always carries a "data" and an "innovations" state.
        self._p_max = max(1, self.p + self.P * self.S)
        self._q_max = max(1, self.q + self.Q * self.S)

        k_states = None
        # Number of leading hidden states that carry the ordinary and seasonal differences ("fast" structure only)
        self._k_diffs = self.d + self.S * self.D

        if state_structure not in SARIMAX_STATE_STRUCTURES:
            raise ValueError(
                f"Got invalid argument {state_structure} for state structure, expected one of "
                f'{", ".join(SARIMAX_STATE_STRUCTURES)}'
            )

        if state_structure == "interpretable" and (self.d + self.D) > 0:
            raise ValueError(
                "Cannot use interpretable state structure with statespace differencing. Difference the "
                'data by hand (leaving NaN values to be interpolated), or use state_structure="fast"'
            )

        if self.state_structure == "fast":
            k_states = (
                max(self.p + self.P * self.S, self.q + self.Q * self.S + 1) + self._k_diffs
            )
        elif self.state_structure == "interpretable":
            k_states = self._p_max + self._q_max

        k_posdef = 1
        k_endog = 1

        super().__init__(
            k_endog,
            k_states,
            k_posdef,
            filter_type,
            verbose=verbose,
            measurement_error=measurement_error,
        )

    @property
    def param_names(self):
        """Names of the free parameters the user must supply priors for, given the model configuration."""
        names = [
            "x0",
            "P0",
            "ar_params",
            "ma_params",
            "seasonal_ar_params",
            "seasonal_ma_params",
            "sigma_state",
            "sigma_obs",
        ]
        # Initial state and covariance are derived from the model dynamics instead of being estimated
        if self.stationary_initialization:
            names.remove("P0")
            names.remove("x0")
        if self.p == 0:
            names.remove("ar_params")
        if self.P == 0:
            names.remove("seasonal_ar_params")
        if self.q == 0:
            names.remove("ma_params")
        if self.Q == 0:
            names.remove("seasonal_ma_params")
        if not self.measurement_error:
            names.remove("sigma_obs")

        return names

    @property
    def param_info(self) -> dict[str, dict[str, Any]]:
        """Shape, constraint, and dims metadata for each active parameter in ``param_names``."""
        info = {
            "x0": {
                "shape": (self.k_states,),
                "constraints": None,
            },
            "P0": {
                "shape": (self.k_states, self.k_states),
                "constraints": "Positive Semi-definite",
            },
            "sigma_obs": {
                "shape": None if self.k_endog == 1 else (self.k_endog,),
                "constraints": "Positive",
            },
            "sigma_state": {
                "shape": None if self.k_posdef == 1 else (self.k_posdef,),
                "constraints": "Positive",
            },
            "ar_params": {
                "shape": (self.p,),
                "constraints": "None",
            },
            "ma_params": {
                "shape": (self.q,),
                "constraints": "None",
            },
            "seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
            "seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
        }

        for name in self.param_names:
            info[name]["dims"] = self.param_dims[name]

        return {name: info[name] for name in self.param_names}

    @property
    def state_names(self):
        """Names of the hidden states, in the order they appear in the state vector."""
        if self.state_structure == "fast":
            p, d, q = self.p, self.d, self.q
            P, D, Q, S = self.P, self.D, self.Q, self.S
            states = make_harvey_state_names(p, d, q, P, D, Q, S)

        elif self.state_structure == "interpretable":
            # Current value followed by its (_p_max - 1) lags, then the current innovation followed by its
            # (_q_max - 1) lags. _p_max / _q_max include the seasonal lags (P * S, Q * S), so the lag states must be
            # generated even when p == 0 or q == 0; otherwise the names would not match k_states.
            states = ["data"]
            states += [f"L{i + 1}.data" for i in range(self._p_max - 1)]
            states += ["innovations"]
            states += [f"L{i + 1}.innovations" for i in range(self._q_max - 1)]
        else:
            raise NotImplementedError()

        return states

    @property
    def observed_states(self):
        """Name of the single observed state: the first entry of ``state_names``."""
        return [self.state_names[0]]

    @property
    def shock_names(self):
        """Name of the single shock driving the system."""
        return ["innovation"]

    @property
    def param_dims(self):
        """Mapping from each active parameter name to its dims tuple (or None for scalars)."""
        coord_map = {
            "x0": (ALL_STATE_DIM,),
            "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
            "sigma_obs": (OBS_STATE_DIM,),
            "sigma_state": (OBS_STATE_DIM,),
            "ar_params": (AR_PARAM_DIM,),
            "ma_params": (MA_PARAM_DIM,),
            "seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
            "seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
        }
        # Univariate model: the standard deviations are scalars and carry no dims
        if self.k_endog == 1:
            coord_map["sigma_state"] = None
            coord_map["sigma_obs"] = None
        if not self.measurement_error:
            del coord_map["sigma_obs"]
        if self.p == 0:
            del coord_map["ar_params"]
        if self.q == 0:
            del coord_map["ma_params"]
        if self.P == 0:
            del coord_map["seasonal_ar_params"]
        if self.Q == 0:
            del coord_map["seasonal_ma_params"]
        if self.stationary_initialization:
            del coord_map["P0"]
            del coord_map["x0"]

        return coord_map

    @property
    def coords(self) -> dict[str, Sequence]:
        """Default coordinates, plus 1-based lag labels for each active (seasonal) AR/MA parameter vector."""
        coords = make_default_coords(self)
        if self.p > 0:
            coords.update({AR_PARAM_DIM: list(range(1, self.p + 1))})
        if self.q > 0:
            coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
        if self.P > 0:
            coords.update({SEASONAL_AR_PARAM_DIM: list(range(1, self.P + 1))})
        if self.Q > 0:
            coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})

        return coords

    def _stationary_initialization(self, mode=None):
        """
        Compute the steady-state initial state and covariance from the current system matrices.

        The mean solves ``(I - T) x0 = c`` and the covariance solves the discrete Lyapunov equation
        ``P0 = T P0 T' + R Q R'``. Both require the transition matrix to be stable.

        Parameters
        ----------
        mode: str, optional
            Compilation mode hint. The "direct" Lyapunov solver is used for small systems (k_states < 5) or when
            mode == "JAX"; otherwise the "bilinear" solver is used.

        Returns
        -------
        x0, P0: TensorVariable
            Steady-state mean and covariance of the hidden state.
        """
        T = self.ssm["transition"]
        R = self.ssm["selection"]
        Q = self.ssm["state_cov"]
        c = self.ssm["state_intercept"]

        x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)

        method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear"
        P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method)

        return x0, P0

    def make_symbolic_graph(self) -> None:
        """Register the model parameters and write them into the statespace matrices held in ``self.ssm``."""
        p, d, q = self.p, self.d, self.q
        P, D, Q, S = self.P, self.D, self.Q, self.S

        # Initial state and covariance can be handled first if we're not doing a stationary initialization
        if not self.stationary_initialization:
            x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)
            P0 = self.make_and_register_variable(
                "P0", shape=(self.k_states, self.k_states), dtype=floatX
            )

            self.ssm["initial_state", :] = x0
            self.ssm["initial_state_cov"] = P0

        # Design matrix has no RVs: ones on the ordinary-difference states, a one on the last state of each seasonal
        # difference block, and a one on the first ARMA state.
        k_lags = self.k_states - self._k_diffs
        self.ssm["design"] = np.r_[[1] * d, ([0] * (S - 1) + [1]) * D, [1], [0] * (k_lags - 1)][
            None
        ]

        # Set up the transition and selection matrices, depending on the requested representation
        if self.state_structure == "fast":
            transition = make_SARIMA_transition_matrix(p, d, q, P, D, Q, S)
            selection = np.r_[
                [0] * self._k_diffs, [1.0], np.zeros(self.k_states - self._k_diffs - 1)
            ][:, None]

            # AR parameters fill the first column of the ARMA block; MA parameters fill the selection vector below
            # the first ARMA state.
            ar_param_idx = np.s_[
                "transition", self._k_diffs : self._k_diffs + self.p, self._k_diffs
            ]
            ma_param_idx = np.s_["selection", 1 + self._k_diffs : 1 + self._k_diffs + self.q, 0]

            self.ssm["transition"] = transition
            self.ssm["selection"] = selection

            if p > 0:
                ar_params = self.make_and_register_variable("ar_params", shape=(p,), dtype=floatX)
                self.ssm[ar_param_idx] = ar_params

            if P > 0:
                seasonal_ar_params = self.make_and_register_variable(
                    "seasonal_ar_params", shape=(P,), dtype=floatX
                )
                # Seasonal lag k * S lives in row k * S - 1 of the ARMA block
                idx_rows = self._k_diffs + (np.arange(1, P + 1) * S) - 1
                S_ar_param_idx = np.s_["transition", idx_rows, self._k_diffs]
                self.ssm[S_ar_param_idx] = seasonal_ar_params

                if p > 0:
                    # Multiplying out (1 - phi(B))(1 - Phi(B^S)) yields -Phi_k * phi_j at lag k * S + j
                    cross_term_idx = np.s_[
                        "transition",
                        idx_rows.repeat(p) + np.tile(np.arange(p), P) + 1,
                        self._k_diffs,
                    ]
                    self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile(
                        ar_params, P
                    )

            if q > 0:
                ma_params = self.make_and_register_variable("ma_params", shape=(q,), dtype=floatX)
                self.ssm[ma_param_idx] = ma_params

            if Q > 0:
                seasonal_ma_params = self.make_and_register_variable(
                    "seasonal_ma_params", shape=(Q,), dtype=floatX
                )
                idx_rows = self._k_diffs + np.arange(1, Q + 1) * S
                S_ma_param_idx = np.s_["selection", idx_rows, 0]
                self.ssm[S_ma_param_idx] = seasonal_ma_params

                if q > 0:
                    # Multiplying out (1 + theta(B))(1 + Theta(B^S)) yields +Theta_k * theta_j at lag k * S + j
                    cross_term_idx = np.s_[
                        "selection", idx_rows.repeat(q) + np.tile(np.arange(q), Q) + 1, 0
                    ]
                    self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile(
                        ma_params, Q
                    )

        elif self.state_structure == "interpretable":
            # All parameters live in the first row of the transition matrix: AR terms multiply the data lags
            # (columns 0 .. _p_max - 1), MA terms multiply the innovation lags (columns _p_max onward).
            ar_param_idx = np.s_["transition", 0, : max(1, p)]
            ma_param_idx = np.s_["transition", 0, self._p_max : self._p_max + max(1, q)]

            # Sub-diagonal shifts every lag down by one step; cut the link between the last data lag and the
            # first innovation state so the two lag chains stay separate.
            transition = np.eye(self.k_states, k=-1)
            transition[-self._q_max, self._p_max - 1] = 0

            # The shock enters both the current data state and the current innovation state
            selection = np.r_[[1.0], np.zeros(self.k_states - 1)][:, None]
            selection[-self._q_max, 0] = 1

            self.ssm["transition"] = transition
            self.ssm["selection"] = selection

            if self.p > 0:
                ar_params = self.make_and_register_variable(
                    "ar_params", shape=(self.p,), dtype=floatX
                )
                self.ssm[ar_param_idx] = ar_params

            if self.P > 0:
                seasonal_ar_params = self.make_and_register_variable(
                    "seasonal_ar_params", shape=(P,), dtype=floatX
                )
                idx_cols = np.arange(1, P + 1) * S - 1
                S_ar_param_idx = np.s_["transition", 0, idx_cols]
                self.ssm[S_ar_param_idx] = seasonal_ar_params

                if p > 0:
                    cross_term_idx = np.s_[
                        "transition", 0, idx_cols.repeat(p) + np.tile(np.arange(p), P) + 1
                    ]
                    self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile(
                        ar_params, P
                    )

            if self.q > 0:
                ma_params = self.make_and_register_variable(
                    "ma_params", shape=(self.q,), dtype=floatX
                )
                self.ssm[ma_param_idx] = ma_params

            if Q > 0:
                seasonal_ma_params = self.make_and_register_variable(
                    "seasonal_ma_params", shape=(Q,), dtype=floatX
                )
                idx_cols = self._p_max + np.arange(1, Q + 1) * S - 1
                S_ma_param_idx = np.s_["transition", 0, idx_cols]
                self.ssm[S_ma_param_idx] = seasonal_ma_params

                if q > 0:
                    cross_term_idx = np.s_[
                        "transition", 0, idx_cols.repeat(q) + np.tile(np.arange(q), Q) + 1
                    ]
                    self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile(
                        ma_params, Q
                    )

        # Set up the state covariance matrix (parameterized by standard deviations, stored as variances)
        state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
        state_cov = self.make_and_register_variable(
            "sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
        )
        self.ssm[state_cov_idx] = state_cov**2

        if self.measurement_error:
            obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))
            obs_cov = self.make_and_register_variable(
                "sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
            )
            self.ssm[obs_cov_idx] = obs_cov**2

        # The initial conditions have to be done last in the case of stationary initialization, because it will depend
        # on c, T, R and Q
        if self.stationary_initialization:
            x0, P0 = self._stationary_initialization()
            self.ssm["initial_state", :] = x0
            self.ssm["initial_state_cov", :, :] = P0