pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,1651 @@
1
+ import functools as ft
2
+ import logging
3
+
4
+ from abc import ABC
5
+ from collections.abc import Sequence
6
+ from itertools import pairwise
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+ import pytensor
11
+ import pytensor.tensor as pt
12
+ import xarray as xr
13
+
14
+ from pytensor import Variable
15
+
16
+ from pymc_extras.statespace.core import PytensorRepresentation
17
+ from pymc_extras.statespace.core.statespace import PyMCStateSpace
18
+ from pymc_extras.statespace.models.utilities import (
19
+ conform_time_varying_and_time_invariant_matrices,
20
+ make_default_coords,
21
+ )
22
+ from pymc_extras.statespace.utils.constants import (
23
+ ALL_STATE_AUX_DIM,
24
+ ALL_STATE_DIM,
25
+ AR_PARAM_DIM,
26
+ LONG_MATRIX_NAMES,
27
+ POSITION_DERIVATIVE_NAMES,
28
+ TIME_DIM,
29
+ )
30
+
31
+ _log = logging.getLogger("pymc.experimental.statespace")
32
+
33
+ floatX = pytensor.config.floatX
34
+
35
+
36
def order_to_mask(order):
    """
    Convert an ``order`` specification into a boolean mask.

    Parameters
    ----------
    order : int or sequence of int
        If an integer ``n`` (Python ``int`` or NumPy integer scalar), a mask of ``n`` ``True`` values is returned.
        Otherwise ``order`` is converted to an array and each entry is cast to ``bool``, so non-zero entries
        become ``True`` and zeros become ``False``.

    Returns
    -------
    mask : np.ndarray of bool
        One-dimensional boolean mask when ``order`` is an integer; otherwise an array with the shape of ``order``.
    """
    # NumPy integer scalars are not instances of ``int``; without this check ``np.int64(3)`` would fall through
    # to the sequence branch and come back as a 0-d array rather than a length-3 mask.
    if isinstance(order, (int, np.integer)):
        return np.ones(order, dtype=bool)

    return np.asarray(order).astype(bool)
41
+
42
+
43
def _frequency_transition_block(s, j):
    """
    Build a symbolic 2x2 block ``[[cos(lam), sin(lam)], [-sin(lam), cos(lam)]]`` with ``lam = 2 * pi * j / s``.

    Parameters
    ----------
    s
        Denominator of the angle (``lam`` is proportional to ``1 / s``).
    j
        Numerator multiplier of the angle (``lam`` is proportional to ``j``).

    Returns
    -------
    block
        The result of ``pt.stack`` applied to the two rows of the block.
    """
    angle = 2 * np.pi * j / s

    top_row = [pt.cos(angle), pt.sin(angle)]
    bottom_row = [-pt.sin(angle), pt.cos(angle)]

    return pt.stack([top_row, bottom_row])
47
+
48
+
49
class StructuralTimeSeries(PyMCStateSpace):
    r"""
    Structural Time Series Model

    The structural time series model, named by [1] and presented in statespace form in [2], is a framework for
    decomposing a univariate time series into level, trend, seasonal, and cycle components. It also admits the
    possibility of exogenous regressors. Unlike the SARIMAX framework, the time series is not assumed to be stationary.

    Parameters
    ----------
    ssm: PytensorRepresentation
        Statespace representation holding the system matrices. A copy is stored on the model.
    state_names, data_names, shock_names, param_names, exog_names: list of str
        Names describing the model. ``"P0"`` (the initial state covariance) is appended to the stored parameter
        names; the lists passed in are not modified.
    param_dims: dict
        Mapping from parameter name to its dims. An entry for ``"P0"`` is added to the stored copy.
    coords: dict
        Mapping from dim name to coordinate labels. It is merged with the output of ``make_default_coords``; on
        key collisions the default coords take precedence.
    param_info, data_info, component_info: dict
        Metadata dictionaries; shallow copies are stored on the model.
    measurement_error: bool
        Forwarded to the ``PyMCStateSpace`` constructor and stored on the model.
    name_to_variable: dict
        Mapping from parameter name to placeholder variable.
    name_to_data: dict, optional
        Mapping from data name to placeholder variable. ``None`` is treated as an empty mapping.
    name: str, optional
        Name of the observed series. Default is ``"data"``.
    verbose: bool
        Forwarded to the ``PyMCStateSpace`` constructor.
    filter_type: str
        Forwarded to the ``PyMCStateSpace`` constructor.

    Notes
    -----

    .. math::
         y_t = \mu_t + \gamma_t + c_t + \varepsilon_t

    """

    def __init__(
        self,
        ssm: PytensorRepresentation,
        state_names: list[str],
        data_names: list[str],
        shock_names: list[str],
        param_names: list[str],
        exog_names: list[str],
        param_dims: dict[str, tuple[int]],
        coords: dict[str, Sequence],
        param_info: dict[str, dict[str, Any]],
        data_info: dict[str, dict[str, Any]],
        component_info: dict[str, dict[str, Any]],
        measurement_error: bool,
        name_to_variable: dict[str, Variable],
        name_to_data: dict[str, Variable] | None = None,
        name: str | None = None,
        verbose: bool = True,
        filter_type: str = "standard",
    ):
        if name is None:
            name = "data"
        self._name = name

        k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog

        # Add the initial state covariance to the parameters. The helper returns new containers, so the
        # objects handed in by the caller (typically a Component) are left untouched.
        param_names, param_dims, param_info = self._add_inital_state_cov_to_properties(
            param_names, param_dims, param_info, k_states
        )
        self._state_names = state_names.copy()
        self._data_names = data_names.copy()
        self._shock_names = shock_names.copy()
        self._param_names = param_names.copy()
        self._param_dims = param_dims.copy()

        # make_default_coords reads the name properties assigned above, so it must run after them.
        default_coords = make_default_coords(self)

        # Merge into a new dict rather than updating the caller's ``coords`` in place.
        self._coords = {**coords, **default_coords}
        self._param_info = param_info.copy()
        self._data_info = data_info.copy()
        self.measurement_error = measurement_error

        super().__init__(
            k_endog,
            k_states,
            # The base class is always initialized with at least one shock; see the k_posdef == 0 branch below.
            max(1, k_posdef),
            filter_type=filter_type,
            verbose=verbose,
            measurement_error=measurement_error,
        )
        self.ssm = ssm.copy()

        if k_posdef == 0:
            # If there is no randomness in the model, add dummy matrices to the representation to avoid errors
            # when we go to construct random variables from the matrices
            self.ssm.k_posdef = self.k_posdef
            self.ssm.shapes["state_cov"] = (1, 1, 1)
            self.ssm["state_cov"] = pt.zeros((1, 1, 1))

            self.ssm.shapes["selection"] = (1, self.k_states, 1)
            self.ssm["selection"] = pt.zeros((1, self.k_states, 1))

        self._component_info = component_info.copy()

        self._name_to_variable = name_to_variable.copy()
        # ``name_to_data`` is optional; calling ``.copy()`` on the ``None`` default would raise AttributeError.
        self._name_to_data = {} if name_to_data is None else name_to_data.copy()

        self._exog_names = exog_names.copy()
        self._needs_exog_data = len(exog_names) > 0

        P0 = self.make_and_register_variable("P0", shape=(self.k_states, self.k_states))
        self.ssm["initial_state_cov"] = P0

    @staticmethod
    def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states):
        """
        Return copies of the parameter metadata extended with an entry for the initial state covariance ``P0``.

        Parameters
        ----------
        param_names: list of str
            Existing parameter names.
        param_dims: dict
            Existing mapping from parameter name to dims.
        param_info: dict
            Existing mapping from parameter name to metadata.
        k_states: int
            Number of hidden states; ``P0`` is recorded with shape ``(k_states, k_states)``.

        Returns
        -------
        param_names, param_dims, param_info
            New list/dicts including ``"P0"``. The inputs are not modified, so building a model several times
            from the same component does not accumulate duplicate ``"P0"`` entries on the component.
        """
        p0_dims = (ALL_STATE_DIM, ALL_STATE_AUX_DIM)

        new_param_names = [*param_names, "P0"]
        new_param_dims = {**param_dims, "P0": p0_dims}
        new_param_info = {
            **param_info,
            "P0": {
                "shape": (k_states, k_states),
                "constraints": "Positive semi-definite",
                "dims": p0_dims,
            },
        }

        return new_param_names, new_param_dims, new_param_info

    @property
    def param_names(self):
        return self._param_names

    @property
    def data_names(self) -> list[str]:
        return self._data_names

    @property
    def state_names(self):
        return self._state_names

    @property
    def observed_states(self):
        return [self._name]

    @property
    def shock_names(self):
        return self._shock_names

    @property
    def param_dims(self):
        return self._param_dims

    @property
    def coords(self) -> dict[str, Sequence]:
        return self._coords

    @property
    def param_info(self) -> dict[str, dict[str, Any]]:
        return self._param_info

    @property
    def data_info(self) -> dict[str, dict[str, Any]]:
        return self._data_info

    def make_symbolic_graph(self) -> None:
        """
        Assign placeholder pytensor variables among statespace matrices in positions where PyMC variables will go.

        Notes
        -----
        This assignment is handled by the components, so this function is implemented only to avoid the
        NotImplementedError raised by the base class.
        """

        pass

    def _state_slices_from_info(self):
        """
        Return one ``slice`` per component, in ``_component_info`` order, locating that component's hidden states
        along the state dimension. Slice bounds are the cumulative sums of each component's ``k_states``.
        """
        state_counts = [comp_info["k_states"] for comp_info in self._component_info.values()]
        boundaries = np.cumsum([0, *state_counts])

        return [slice(start, stop) for start, stop in pairwise(boundaries)]

    def _hidden_states_from_data(self, data):
        """
        Reduce an array of hidden states to the interpretable subcomponents of the model.

        Parameters
        ----------
        data: array
            Array whose last axis indexes the hidden states of the full model.

        Returns
        -------
        array
            Concatenation along the last axis of, for every component with a non-``None`` ``obs_state_idx``:
            a single column holding the sum of the states flagged by ``obs_state_idx`` when
            ``combine_hidden_states`` is true, otherwise one column per state name of the component.
        """
        info = self._component_info
        columns = []

        for comp_name, state_slice in zip(info, self._state_slices_from_info()):
            obs_idx = info[comp_name]["obs_state_idx"]
            if obs_idx is None:
                # Components without an observed-state index contribute no output columns.
                continue

            comp_data = data[..., state_slice]
            if info[comp_name]["combine_hidden_states"]:
                selected = np.flatnonzero(obs_idx)
                columns.append(comp_data[..., selected].sum(axis=-1)[..., None])
            else:
                n_named_states = len(self.state_names[state_slice])
                columns.extend(comp_data[..., j, None] for j in range(n_named_states))

        return np.concatenate(columns, axis=-1)

    def _get_subcomponent_names(self):
        """
        Return the labels of the columns produced by ``_hidden_states_from_data``, in the same order.

        Components with ``combine_hidden_states`` contribute their own name; other components contribute
        ``"{component}[{state}]"`` for each of their state names. Components whose ``obs_state_idx`` is ``None``
        are skipped, mirroring ``_hidden_states_from_data``, so the number of labels always matches the number of
        columns.
        """
        info = self._component_info
        labels = []

        for comp_name, state_slice in zip(info, self._state_slices_from_info()):
            if info[comp_name]["obs_state_idx"] is None:
                continue

            if info[comp_name]["combine_hidden_states"]:
                labels.append(comp_name)
            else:
                labels.extend(
                    f"{comp_name}[{state_name}]" for state_name in self.state_names[state_slice]
                )

        return labels

    def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
        r"""
        Extract interpretable hidden states from an InferenceData returned by a PyMCStateSpace sampling method

        Parameters
        ----------
        idata: Dataset
            A Dataset object, returned by a PyMCStateSpace sampling method

        Returns
        -------
        idata: Dataset
            A Dataset object with hidden states transformed to represent only the "interpretable" subcomponents
            of the structural model.

        Raises
        ------
        ValueError
            If no variable in ``idata`` has a trailing dimension of length ``k_states``.

        Notes
        -----
        In general, a structural statespace model can be represented as:

        .. math::
            y_t = \mu_t + \nu_t + \cdots + \gamma_t + c_t + \xi_t + \varepsilon_t \tag{1}

        Where:

            - :math:`\mu_t` is the level of the data at time t
            - :math:`\nu_t` is the slope of the data at time t
            - :math:`\cdots` are higher time derivatives of the position (acceleration, jerk, etc) at time t
            - :math:`\gamma_t` is the seasonal component at time t
            - :math:`c_t` is the cycle component at time t
            - :math:`\xi_t` is the autoregressive error at time t
            - :math:`\varepsilon_t` is the measurement error at time t

        In state space form, some or all of these components are represented as linear combinations of other
        subcomponents, making interpretation of the outputs difficult. The purpose of this function is
        to take the expanded statespace representation and return a "reduced form" of only the components shown in
        equation (1).

        Variables whose last dimension does not have length ``k_states`` are dropped from the result, and a
        warning listing them is logged.
        """

        def _extract_and_transform_variable(data_array, new_state_names):
            # The last two dims are taken to be (time, state); any leading dims are left untouched.
            *_, time_dim, state_dim = data_array.dims
            transformed = xr.apply_ufunc(
                self._hidden_states_from_data,
                data_array,
                input_core_dims=[[time_dim, state_dim]],
                output_core_dims=[[time_dim, state_dim]],
                # The state dimension changes length, so it has to be excluded from alignment.
                exclude_dims={state_dim},
            )
            transformed.coords.update({state_dim: new_state_names})
            return transformed

        var_names = list(idata.data_vars.keys())
        new_state_names = self._get_subcomponent_names()

        # A latent variable needs at least a (time, state) pair of trailing dims; variables with fewer than two
        # dims are treated as non-latent instead of raising while inspecting their shape.
        latent_names = [
            name
            for name in var_names
            if idata[name].ndim >= 2 and idata[name].shape[-1] == self.k_states
        ]
        # Keep dataset order so that the warning text is deterministic.
        dropped_vars = [name for name in var_names if name not in latent_names]

        if len(dropped_vars) > 0:
            _log.warning(
                f'Variables {", ".join(dropped_vars)} do not contain all hidden states (their last dimension '
                f"is not {self.k_states}). They will not be present in the modified idata."
            )
        if len(dropped_vars) == len(var_names):
            raise ValueError(
                "Provided idata had no variables with all hidden states; cannot extract components."
            )

        idata_new = xr.Dataset(
            {
                name: _extract_and_transform_variable(idata[name], new_state_names)
                for name in latent_names
            }
        )
        return idata_new
316
+
317
+
318
class Component(ABC):
    r"""
    Base class for a component of a structural timeseries model.

    This base class contains a subset of the class attributes of the PyMCStateSpace class, and none of the class
    methods. The purpose of a component is to allow the partial definition of a structural model. Components are
    assembled into a full model by the StructuralTimeSeries class.

    Parameters
    ----------
    name: str
        The name of the component
    k_endog: int
        Number of endogenous variables being modeled. Currently, must be one because structural models only support
        univariate data.
    k_states: int
        Number of hidden states in the component model
    k_posdef: int
        Rank of the state covariance matrix, or the number of sources of innovations in the component model
    state_names, data_names, shock_names, param_names, exog_names: list of str, optional
        Names associated with the component. Each defaults to an empty list.
    representation: PytensorRepresentation, optional
        Statespace representation to use for the component. If None, a new ``PytensorRepresentation`` is created
        from ``k_endog``, ``k_states`` and ``k_posdef``.
    measurement_error: bool
        Whether the observation associated with the component has measurement error. Default is False.
    combine_hidden_states: bool
        Flag read when interpretable components are extracted from sampled hidden states. When ``True``, the
        component is reported as a single series equal to the sum of the hidden states selected by
        ``np.flatnonzero(obs_state_idxs)``; when ``False``, each hidden state is reported separately. Should be True
        in models where hidden states individually have no interpretation, such as seasonal or autoregressive
        components.
    component_from_sum: bool
        When True, ``populate_component_properties`` and ``make_symbolic_graph`` are not called during
        initialization. Used by ``__add__``, which fills in the properties of the combined component itself.
    obs_state_idxs: array-like, optional
        Indicator of which hidden states are selected when ``combine_hidden_states`` is True. Stored under the
        ``"obs_state_idx"`` key of the component info.
    """

    def __init__(
        self,
        name,
        k_endog,
        k_states,
        k_posdef,
        state_names=None,
        data_names=None,
        shock_names=None,
        param_names=None,
        exog_names=None,
        representation: PytensorRepresentation | None = None,
        measurement_error=False,
        combine_hidden_states=True,
        component_from_sum=False,
        obs_state_idxs=None,
    ):
        self.name = name
        self.k_endog = k_endog
        self.k_states = k_states
        self.k_posdef = k_posdef
        self.measurement_error = measurement_error

        self.state_names = state_names if state_names is not None else []
        self.data_names = data_names if data_names is not None else []
        self.shock_names = shock_names if shock_names is not None else []
        self.param_names = param_names if param_names is not None else []
        self.exog_names = exog_names if exog_names is not None else []

        self.needs_exog_data = len(self.exog_names) > 0
        self.coords = {}
        self.param_dims = {}

        self.param_info = {}
        self.data_info = {}

        self.param_counts = {}

        if representation is None:
            self.ssm = PytensorRepresentation(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)
        else:
            self.ssm = representation

        self._name_to_variable = {}
        self._name_to_data = {}

        if not component_from_sum:
            self.populate_component_properties()
            self.make_symbolic_graph()

        self._component_info = {
            self.name: {
                "k_states": self.k_states,
                "k_endog": self.k_endog,
                # Misspelled key kept alongside the corrected one for backward compatibility.
                "k_enodg": self.k_endog,
                "k_posdef": self.k_posdef,
                "combine_hidden_states": combine_hidden_states,
                "obs_state_idx": obs_state_idxs,
            }
        }

    def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
        r"""
        Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary

        Parameters
        ----------
        name : str
            The name of the placeholder variable. Must be the name of a model parameter.
        shape : int or tuple of int
            Shape of the parameter
        dtype : str, default pytensor.config.floatX
            dtype of the parameter

        Returns
        -------
        placeholder : Variable
            The newly created placeholder variable.

        Raises
        ------
        ValueError
            If ``name`` is not in ``param_names``, or if it has already been registered.

        Notes
        -----
        Symbolic pytensor variables are used in the ``make_symbolic_graph`` method as placeholders for PyMC random
        variables. The change is made in the ``_insert_random_variables`` method via ``pytensor.graph_replace``. To
        make the change, a dictionary mapping pytensor variables to PyMC random variables needs to be constructed.

        The purpose of this method is to:
            1.  Create the placeholder symbolic variables
            2.  Register the placeholder variable in the ``_name_to_variable`` dictionary

        The shape provided here will define the shape of the prior that will need to be provided by the user.
        """
        if name not in self.param_names:
            raise ValueError(
                f"{name} is not a model parameter. All placeholder variables should correspond to model "
                f"parameters."
            )

        if name in self._name_to_variable:
            raise ValueError(
                f"{name} is already a registered placeholder variable with shape "
                f"{self._name_to_variable[name].type.shape}"
            )

        placeholder = pt.tensor(name, shape=shape, dtype=dtype)
        self._name_to_variable[name] = placeholder
        return placeholder

    def make_and_register_data(self, name, shape, dtype=floatX) -> Variable:
        r"""
        Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary

        Parameters
        ----------
        name : str
            The name of the placeholder data. Must be the name of an expected data variable.
        shape : int or tuple of int
            Shape of the data
        dtype : str, default pytensor.config.floatX
            dtype of the data

        Returns
        -------
        placeholder : Variable
            The newly created placeholder variable.

        Raises
        ------
        ValueError
            If ``name`` is not in ``data_names``, or if it has already been registered.

        Notes
        -----
        See docstring for make_and_register_variable for more details. This function is similar, but handles data
        inputs instead of model parameters.
        """
        if name not in self.data_names:
            # The check is against data_names, so the message talks about data rather than parameters.
            raise ValueError(
                f"{name} is not a model data variable. All placeholder data variables should correspond to "
                f"names in data_names."
            )

        if name in self._name_to_data:
            raise ValueError(
                f"{name} is already a registered placeholder variable with shape "
                f"{self._name_to_data[name].type.shape}"
            )

        placeholder = pt.tensor(name, shape=shape, dtype=dtype)
        self._name_to_data[name] = placeholder
        return placeholder

    def make_symbolic_graph(self) -> None:
        """Hook for subclasses; called from ``__init__`` unless ``component_from_sum`` is True."""
        raise NotImplementedError

    def populate_component_properties(self):
        """Hook for subclasses; called from ``__init__`` (before ``make_symbolic_graph``) unless
        ``component_from_sum`` is True."""
        raise NotImplementedError

    def _get_combined_shapes(self, other):
        """
        Return ``(k_states, k_posdef, k_endog)`` for the sum of this component and ``other``.

        ``k_states`` and ``k_posdef`` are added; ``k_endog`` must be equal on both sides.

        Raises
        ------
        NotImplementedError
            If the two components have different ``k_endog``.
        """
        if self.k_endog != other.k_endog:
            raise NotImplementedError(
                "Merging elements with different numbers of observed states is not supported."
            )

        k_states = self.k_states + other.k_states
        k_posdef = self.k_posdef + other.k_posdef
        k_endog = self.k_endog

        return k_states, k_posdef, k_endog

    def _combine_statespace_representations(self, other):
        """
        Build the ``PytensorRepresentation`` of the sum of this component and ``other``.

        State-sized matrices (transition, selection, state covariance, initial state covariance) are combined
        block-diagonally; the initial state and state intercept are concatenated; the design matrices are
        concatenated along the last axis; the observation intercept and observation covariance are added. Each
        combined matrix keeps the name of this component's matrix.
        """

        def make_slice(name, x, o_x):
            # Request the matrix with as many full slices as the larger ndim of the two operands.
            # NOTE(review): presumably this makes both sides come back with matching ndim when only one of
            # them is time-varying -- confirm against PytensorRepresentation.__getitem__.
            ndim = max(x.ndim, o_x.ndim)
            return (name,) + (slice(None, None, None),) * ndim

        k_states, k_posdef, k_endog = self._get_combined_shapes(other)

        self_matrices = [self.ssm[name] for name in LONG_MATRIX_NAMES]
        other_matrices = [other.ssm[name] for name in LONG_MATRIX_NAMES]

        x0, P0, c, d, T, Z, R, H, Q = (
            self.ssm[make_slice(name, x, o_x)]
            for name, x, o_x in zip(LONG_MATRIX_NAMES, self_matrices, other_matrices)
        )
        o_x0, o_P0, o_c, o_d, o_T, o_Z, o_R, o_H, o_Q = (
            other.ssm[make_slice(name, x, o_x)]
            for name, x, o_x in zip(LONG_MATRIX_NAMES, self_matrices, other_matrices)
        )

        initial_state = pt.concatenate(conform_time_varying_and_time_invariant_matrices(x0, o_x0))
        initial_state.name = x0.name

        initial_state_cov = pt.linalg.block_diag(P0, o_P0)
        initial_state_cov.name = P0.name

        state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c))
        state_intercept.name = c.name

        obs_intercept = d + o_d
        obs_intercept.name = d.name

        transition = pt.linalg.block_diag(T, o_T)
        transition.name = T.name

        design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1)
        design.name = Z.name

        selection = pt.linalg.block_diag(R, o_R)
        selection.name = R.name

        obs_cov = H + o_H
        obs_cov.name = H.name

        state_cov = pt.linalg.block_diag(Q, o_Q)
        state_cov.name = Q.name

        new_ssm = PytensorRepresentation(
            k_endog=k_endog,
            k_states=k_states,
            k_posdef=k_posdef,
            initial_state=initial_state,
            initial_state_cov=initial_state_cov,
            state_intercept=state_intercept,
            obs_intercept=obs_intercept,
            transition=transition,
            design=design,
            selection=selection,
            obs_cov=obs_cov,
            state_cov=state_cov,
        )

        return new_ssm

    def _combine_property(self, other, name):
        """
        Merge the attribute ``name`` of this component with the same attribute of ``other``.

        Lists are concatenated (this component first). Dicts are merged into a new dict, with entries from
        ``other`` overriding entries of this component on key collisions. For any other attribute type ``None``
        is returned.
        """
        self_prop = getattr(self, name)
        other_prop = getattr(other, name)

        if isinstance(self_prop, list):
            return self_prop + other_prop
        if isinstance(self_prop, dict):
            merged = self_prop.copy()
            merged.update(other_prop)
            return merged

        return None

    def _combine_component_info(self, other):
        """
        Merge the component info of this component and ``other`` into a new dict.

        Keys starting with ``"StateSpace"`` are skipped.

        Raises
        ------
        ValueError
            If the same component name appears on both sides.
        """
        combined_info = {}
        for component_info in (self._component_info, other._component_info):
            for key, value in component_info.items():
                if key.startswith("StateSpace"):
                    continue
                if key in combined_info:
                    raise ValueError(f"Found duplicate component named {key}")
                combined_info[key] = value

        return combined_info

    def _make_combined_name(self):
        """Return ``"StateSpace[<name>, <name>, ...]"`` built from the keys of ``_component_info``."""
        components = self._component_info.keys()
        name = f'StateSpace[{", ".join(components)}]'
        return name

    def __add__(self, other):
        """
        Combine this component with ``other`` into a new ``Component``.

        The statespace representations are merged with ``_combine_statespace_representations``, list- and
        dict-valued properties are merged with ``_combine_property``, and the new component is named after all
        of the components it contains. Neither operand is modified.
        """
        state_names = self._combine_property(other, "state_names")
        data_names = self._combine_property(other, "data_names")
        param_names = self._combine_property(other, "param_names")
        shock_names = self._combine_property(other, "shock_names")
        param_info = self._combine_property(other, "param_info")
        data_info = self._combine_property(other, "data_info")
        param_dims = self._combine_property(other, "param_dims")
        coords = self._combine_property(other, "coords")
        exog_names = self._combine_property(other, "exog_names")

        _name_to_variable = self._combine_property(other, "_name_to_variable")
        _name_to_data = self._combine_property(other, "_name_to_data")

        measurement_error = any([self.measurement_error, other.measurement_error])

        k_states, k_posdef, k_endog = self._get_combined_shapes(other)
        ssm = self._combine_statespace_representations(other)

        new_comp = Component(
            name="",
            # Use the validated combined value so the component agrees with the representation built above.
            k_endog=k_endog,
            k_states=k_states,
            k_posdef=k_posdef,
            measurement_error=measurement_error,
            representation=ssm,
            # Skip the subclass hooks; all properties are assigned explicitly below.
            component_from_sum=True,
        )
        new_comp._component_info = self._combine_component_info(other)
        new_comp.name = new_comp._make_combined_name()

        names_and_props = [
            ("state_names", state_names),
            ("data_names", data_names),
            ("param_names", param_names),
            ("shock_names", shock_names),
            ("param_dims", param_dims),
            ("coords", coords),
            ("param_info", param_info),
            ("data_info", data_info),
            ("exog_names", exog_names),
            ("_name_to_variable", _name_to_variable),
            ("_name_to_data", _name_to_data),
        ]

        for prop, value in names_and_props:
            setattr(new_comp, prop, value)

        return new_comp

    def build(self, name=None, filter_type="standard", verbose=True):
        """
        Build a StructuralTimeSeries statespace model from the current component(s)

        Parameters
        ----------
        name: str, optional
            Name of the observed data being modeled. Default is "data"

        filter_type : str, optional
            The type of Kalman filter to use. Valid options are "standard", "univariate", "single", "cholesky", and
            "steady_state". For more information, see the docs for each filter. Default is "standard".

        verbose : bool, optional
            If True, displays information about the initialized model. Defaults to True.

        Returns
        -------
        PyMCStateSpace
            An initialized instance of a PyMCStateSpace, constructed using the system matrices contained in the
            components.
        """

        return StructuralTimeSeries(
            self.ssm,
            name=name,
            state_names=self.state_names,
            data_names=self.data_names,
            shock_names=self.shock_names,
            param_names=self.param_names,
            param_dims=self.param_dims,
            coords=self.coords,
            param_info=self.param_info,
            data_info=self.data_info,
            component_info=self._component_info,
            measurement_error=self.measurement_error,
            exog_names=self.exog_names,
            name_to_variable=self._name_to_variable,
            name_to_data=self._name_to_data,
            filter_type=filter_type,
            verbose=verbose,
        )
689
+
690
+
691
+ class LevelTrendComponent(Component):
692
+ r"""
693
+ Level and trend component of a structural time series model
694
+
695
+ Parameters
696
+ ----------
697
+ __________
698
+ order : int
699
+
700
+ Number of time derivatives of the trend to include in the model. For example, when order=3, the trend will
701
+ be of the form ``y = a + b * t + c * t ** 2``, where the coefficients ``a, b, c`` come from the initial
702
+ state values.
703
+
704
+ innovations_order : int or sequence of int, optional
705
+
706
+ The number of stochastic innovations to include in the model. By default, ``innovations_order = order``
707
+
708
+ Notes
709
+ -----
710
+ This class implements the level and trend components of the general structural time series model. In the most
711
+ general form, the level and trend is described by a system of two time-varying equations.
712
+
713
+ .. math::
714
+ \begin{align}
715
+ \mu_{t+1} &= \mu_t + \nu_t + \zeta_t \\
716
+ \nu_{t+1} &= \nu_t + \xi_t
717
+ \zeta_t &\sim N(0, \sigma_\zeta) \\
718
+ \xi_t &\sim N(0, \sigma_\xi)
719
+ \end{align}
720
+
721
+ Where :math:`\mu_{t+1}` is the mean of the timeseries at time t, and :math:`\nu_t` is the drift or the slope of
722
+ the process. When both innovations :math:`\zeta_t` and :math:`\xi_t` are included in the model, it is known as a
723
+ *local linear trend* model. This system of two equations, corresponding to ``order=2``, can be expanded or
724
+ contracted by adding or removing equations. ``order=3`` would add an acceleration term to the sytsem:
725
+
726
+ .. math::
727
+ \begin{align}
728
+ \mu_{t+1} &= \mu_t + \nu_t + \zeta_t \\
729
+ \nu_{t+1} &= \nu_t + \eta_t + \xi_t \\
730
+ \eta_{t+1} &= \eta_{t-1} + \omega_t \\
731
+ \zeta_t &\sim N(0, \sigma_\zeta) \\
732
+ \xi_t &\sim N(0, \sigma_\xi) \\
733
+ \omega_t &\sim N(0, \sigma_\omega)
734
+ \end{align}
735
+
736
+ After setting all innovation terms to zero and defining initial states :math:`\mu_0, \nu_0, \eta_0`, these equations
737
+ can be collapsed to:
738
+
739
+ .. math::
740
+ \mu_t = \mu_0 + \nu_0 \cdot t + \eta_0 \cdot t^2
741
+
742
+ Which clarifies how the order and initial states influence the model. In particular, the initial states are the
743
+ coefficients on the intercept, slope, acceleration, and so on.
744
+
745
+ In this light, allowing for innovations can be understood as allowing these coefficients to vary over time. Each
746
+ component can be individually selected for time variation by passing a list to the ``innovations_order`` argument.
747
+ For example, a constant intercept with time varying trend and acceleration is specified as ``order=3,
748
+ innovations_order=[0, 1, 1]``.
749
+
750
+ By choosing the ``order`` and ``innovations_order``, a large variety of models can be obtained. Notable
751
+ models include:
752
+
753
+ * Constant intercept, ``order=1, innovations_order=0``
754
+
755
+ .. math::
756
+ \mu_t = \mu
757
+
758
+ * Constant linear slope, ``order=2, innovations_order=0``
759
+
760
+ .. math::
761
+ \mu_t = \mu_{t-1} + \nu
762
+
763
+ * Gaussian Random Walk, ``order=1, innovations_order=1``
764
+
765
+ .. math::
766
+ \mu_t = \mu_{t-1} + \zeta_t
767
+
768
+ * Gaussian Random Walk with Drift, ``order=2, innovations_order=1``
769
+
770
+ .. math::
771
+ \mu_t = \mu_{t-1} + \nu + \zeta_t
772
+
773
+ * Smooth Trend, ``order=2, innovations_order=[0, 1]``
774
+
775
+ .. math::
776
+ \begin{align}
777
+ \mu_t &= \mu_{t-1} + \nu_{t-1} \\
778
+ \nu_t &= \nu_{t-1} + \xi_t
779
+ \end{align}
780
+
781
+ * Local Linear Trend, ``order=2, innovations_order=2``
782
+
783
+ [1] notes that the smooth trend model produces more gradually changing slopes than the full local linear trend
784
+ model, and is equivalent to an "integrated trend model".
785
+
786
+ References
787
+ ----------
788
+ .. [1] Durbin, James, and Siem Jan Koopman. 2012.
789
+ Time Series Analysis by State Space Methods: Second Edition.
790
+ Oxford University Press.
791
+
792
+ """
793
+
794
def __init__(
    self,
    order: int | list[int] = 2,
    innovations_order: int | list[int] | None = None,
    name: str = "LevelTrend",
):
    """
    Work out how many trend states exist and which of them receive innovations.

    Parameters
    ----------
    order: int or list of int, default 2
        Specification of the trend states, converted to a mask with ``order_to_mask``. The position of the last
        non-zero entry of that mask determines the number of states.

    innovations_order: int or list of int, optional
        Which states receive stochastic innovations. If an int ``n``, the first ``n`` states get an innovation
        (``0`` disables all innovations). If a list, it is converted to a mask with ``order_to_mask``; entries
        beyond the number of states are ignored. Defaults to ``order``.

    name: str, default "LevelTrend"
        Name of the component.

    Raises
    ------
    ValueError
        If ``order`` ends in zeros, or if a list ``innovations_order`` has fewer entries than there are states.
    """
    if innovations_order is None:
        innovations_order = order

    self._order_mask = order_to_mask(order)
    max_state = np.flatnonzero(self._order_mask)[-1].item() + 1

    # If the user passes excess zeros, raise an error. The alternative is to prune them, but this would cause
    # the shape of the state to be different to what the user expects.
    if len(self._order_mask) > max_state:
        raise ValueError(
            f"order={order} is invalid. The highest derivative should not be set to zero. If you want a "
            f"lower order model, explicitly omit the zeros."
        )
    k_states = max_state

    if isinstance(innovations_order, int):
        n_innovations = innovations_order
        innovations_mask = order_to_mask(k_states)
        # Keep innovations on the first n states only; n <= 0 switches every innovation off.
        innovations_mask[max(n_innovations, 0) :] = False
    else:
        innovations_mask = order_to_mask(innovations_order)
        if len(innovations_mask) < k_states:
            # Fail here with a clear message: a short mask cannot be used later as a boolean column index
            # into the (k_states, k_states) identity when the selection matrix is built.
            raise ValueError(
                f"innovations_order={innovations_order} has {len(innovations_mask)} entries, but the model "
                f"has {k_states} states. Provide one entry per state."
            )

    self.innovations_order = innovations_mask[:max_state]
    # Count shocks on the truncated mask so k_posdef always matches the columns of the selection matrix.
    k_posdef = int(sum(self.innovations_order))

    super().__init__(
        name,
        k_endog=1,
        k_states=k_states,
        k_posdef=k_posdef,
        measurement_error=False,
        combine_hidden_states=False,
        obs_state_idxs=np.array([1.0] + [0.0] * (k_states - 1)),
    )
837
+
838
def populate_component_properties(self):
    """
    Fill in parameter names, state and shock names, dims, coords and parameter metadata.

    ``initial_trend`` is always present. ``sigma_trend`` and the shock names are added only when the component
    has at least one innovation (``k_posdef > 0``).
    """
    derivative_names = POSITION_DERIVATIVE_NAMES[: self.k_states]

    self.param_names = ["initial_trend"]
    self.state_names = [
        label for label, is_active in zip(derivative_names, self._order_mask) if is_active
    ]
    self.param_dims = {"initial_trend": ("trend_state",)}
    self.coords = {"trend_state": self.state_names}
    self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}}

    if self.k_posdef > 0:
        self.param_names += ["sigma_trend"]
        self.shock_names = [
            label for label, has_shock in zip(derivative_names, self.innovations_order) if has_shock
        ]
        self.param_dims["sigma_trend"] = ("trend_shock",)
        self.coords["trend_shock"] = self.shock_names
        self.param_info["sigma_trend"] = {"shape": (self.k_posdef,), "constraints": "Positive"}

    # Copy each parameter's dims into its metadata entry.
    for param_name in self.param_names:
        self.param_info[param_name]["dims"] = self.param_dims[param_name]
857
+
858
def make_symbolic_graph(self) -> None:
    """
    Write the trend component's matrices into ``self.ssm``.

    The initial state is the registered ``initial_trend`` variable, the transition matrix is upper triangular
    with ones, the selection matrix keeps the identity columns flagged in ``self.innovations_order``, and only
    the first state enters the design matrix. When there are innovations, the diagonal of the state covariance
    is set to ``sigma_trend**2``.
    """
    n_states = self.k_states

    initial_trend = self.make_and_register_variable("initial_trend", shape=(n_states,))
    self.ssm["initial_state", :] = initial_trend

    upper_rows, upper_cols = np.triu_indices(n_states)
    self.ssm["transition", upper_rows, upper_cols] = 1

    # One column per state that receives a shock.
    self.ssm["selection", :, :] = np.eye(n_states)[:, self.innovations_order]

    design_row = np.zeros(n_states)
    design_row[0] = 1.0
    self.ssm["design", 0, :] = design_row

    if self.k_posdef <= 0:
        return

    sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,))
    shock_rows, shock_cols = np.diag_indices(self.k_posdef)
    self.ssm["state_cov", shock_rows, shock_cols] = sigma_trend**2
875
+
876
+
877
class MeasurementError(Component):
    r"""
    Measurement error term for a structural timeseries model

    Parameters
    ----------
    name: str, optional
        Name of the component. It determines the name of the variance parameter added to the model, which is
        ``sigma_{name}``. Default is "MeasurementError".

    Notes
    -----
    This component should only be used in combination with other components, because it has no states. Its only use
    is to add a variance parameter to the model, associated with the observation noise matrix H.

    Examples
    --------
    Create and estimate a deterministic linear trend with measurement error

    .. code:: python

        from pymc_extras.statespace import structural as st
        import pymc as pm
        import pytensor.tensor as pt

        trend = st.LevelTrendComponent(order=2, innovations_order=0)
        error = st.MeasurementError(name="obs")
        ss_mod = (trend + error).build()

        with pm.Model(coords=ss_mod.coords) as model:
            P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])
            initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
            sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs'])

            ss_mod.build_statespace_graph(data, mode='JAX')
            idata = pm.sample(nuts_sampler='numpyro')
    """

    def __init__(self, name: str = "MeasurementError"):
        # The component has no hidden states and no shocks; it only contributes observation noise.
        k_endog = 1
        k_states = 0
        k_posdef = 0

        super().__init__(
            name, k_endog, k_states, k_posdef, measurement_error=True, combine_hidden_states=False
        )

    def populate_component_properties(self):
        """Declare the single scalar, positive parameter ``sigma_{name}``."""
        self.param_names = [f"sigma_{self.name}"]
        self.param_dims = {}
        self.param_info = {
            f"sigma_{self.name}": {
                "shape": (),
                "constraints": "Positive",
                "dims": None,
            }
        }

    def make_symbolic_graph(self) -> None:
        """Register ``sigma_{name}`` and place its square on the diagonal of the observation covariance."""
        error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=())
        diag_idx = np.diag_indices(self.k_endog)
        idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]
        self.ssm[idx] = error_sigma**2
941
+
942
+
943
class AutoregressiveComponent(Component):
    r"""
    Autoregressive timeseries component

    Parameters
    ----------
    order: int or sequence of int

        If int, the number of lags to include in the model.
        If a sequence, an array-like of zeros and ones indicating which lags to include in the model.

    name: str, default "AutoRegressive"
        Name of the component. Used to label the innovation shock.

    Notes
    -----
    An autoregressive component can be thought of as a way of introducing serially correlated errors into the model.
    The process is modeled:

    .. math::
        x_t = \sum_{i=1}^p \rho_i x_{t-i} + \varepsilon_t, \quad \varepsilon_t \sim N(0, \sigma_{ar})

    Where ``p``, the number of autoregressive terms to model, is the order of the process. By default, all lags up to
    ``p`` are included in the model. To disable lags, pass a list of zeros and ones to the ``order`` argument. For
    example, ``order=[1, 1, 0, 1]`` would become:

    .. math::
        x_t = \rho_1 x_{t-1} + \rho_2 x_{t-2} + \rho_4 x_{t-4} + \varepsilon_t

    The coefficient :math:`\rho_3` has been constrained to zero, so ``ar_params`` has three entries in this case.

    .. warning:: This class is meant to be used as a component in a structural time series model. For modeling of
                 stationary processes with ARIMA, use ``statespace.BayesianSARIMA``.

    Examples
    --------
    Model a timeseries as an AR(2) process with non-zero mean:

    .. code:: python

        from pymc_extras.statespace import structural as st
        import pymc as pm
        import pytensor.tensor as pt

        trend = st.LevelTrendComponent(order=1, innovations_order=0)
        ar = st.AutoregressiveComponent(2)
        ss_mod = (trend + ar).build()

        with pm.Model(coords=ss_mod.coords) as model:
            P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])
            initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
            ar_params = pm.Normal('ar_params', dims=ss_mod.param_dims['ar_params'])
            sigma_ar = pm.Exponential('sigma_ar', 1, dims=ss_mod.param_dims['sigma_ar'])

            ss_mod.build_statespace_graph(data, mode='JAX')
            idata = pm.sample(nuts_sampler='numpyro')

    """

    def __init__(self, order: int | list[int] = 1, name: str = "AutoRegressive"):
        order = order_to_mask(order)
        # 1-based lag numbers of the lags that are actually estimated.
        ar_lags = np.flatnonzero(order).ravel().astype(int) + 1
        k_states = len(order)

        self.order = order
        self.ar_lags = ar_lags

        # NOTE(review): measurement_error=True is kept exactly as in the original; confirm it is intended for a
        # component that has hidden states.
        super().__init__(
            name=name,
            k_endog=1,
            k_states=k_states,
            k_posdef=1,
            measurement_error=True,
            combine_hidden_states=True,
            obs_state_idxs=np.r_[[1.0], np.zeros(k_states - 1)],
        )

    def populate_component_properties(self):
        """Declare state and shock names plus metadata for ``ar_params`` and ``sigma_ar``."""
        self.state_names = [f"L{i + 1}.data" for i in range(self.k_states)]
        self.shock_names = [f"{self.name}_innovation"]
        self.param_names = ["ar_params", "sigma_ar"]
        self.param_dims = {"ar_params": (AR_PARAM_DIM,)}
        self.coords = {AR_PARAM_DIM: self.ar_lags.tolist()}

        self.param_info = {
            "ar_params": {
                # One coefficient per included lag. This matches both the coords above and the variable
                # registered in make_symbolic_graph, which differ from k_states when lags are disabled.
                "shape": (int(sum(self.order)),),
                "constraints": None,
                "dims": (AR_PARAM_DIM,),
            },
            "sigma_ar": {"shape": (), "constraints": "Positive", "dims": None},
        }

    def make_symbolic_graph(self) -> None:
        """
        Write the AR matrices into ``self.ssm``.

        The transition matrix has ones on its first sub-diagonal, and its first row holds the AR coefficients
        at the columns of the included lags. The single shock enters the first state with variance
        ``sigma_ar**2``.
        """
        k_nonzero = int(sum(self.order))
        ar_params = self.make_and_register_variable("ar_params", shape=(k_nonzero,))
        sigma_ar = self.make_and_register_variable("sigma_ar", shape=())

        T = np.eye(self.k_states, k=-1)
        self.ssm["transition", :, :] = T
        self.ssm["selection", 0, 0] = 1
        self.ssm["design", 0, 0] = 1

        # Row 0, at the columns of the included lags, receives the coefficients.
        ar_idx = ("transition", np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
        self.ssm[ar_idx] = ar_params

        cov_idx = ("state_cov", *np.diag_indices(1))
        self.ssm[cov_idx] = sigma_ar**2
1048
+
1049
+
1050
class TimeSeasonality(Component):
    r"""
    Seasonal component, modeled in the time domain

    Parameters
    ----------
    season_length: int
        The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
        daily data with weekly seasonal pattern, etc.

    innovations: bool, default True
        Whether to include stochastic innovations in the strength of the seasonal effect

    name: str, default None
        A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
        components are included in the same model. Default is ``f"Seasonal[s={season_length}]"``

    state_names: list of str, default None
        List of strings for seasonal effect labels. If provided, it must be of length ``season_length``. An example
        would be ``state_names = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']`` when data is daily with a weekly
        seasonal pattern (``season_length = 7``).

        If None, states will be named ``[f"{name}_0", ..., f"{name}_{season_length - 1}"]``

    pop_state: bool, default True
        If True, the first entry of ``state_names`` is dropped and the component has ``season_length - 1`` states,
        as described in the Notes. If False, all ``season_length`` states are kept and the transition matrix simply
        rotates through them; no sum-to-zero restriction is imposed by the component in that case.

    Notes
    -----
    A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
    model seasonal effects, the implementation used here is the one described by [1] as the "canonical" time domain
    representation. The seasonal component can be expressed:

    .. math::
        \gamma_t = -\sum_{i=1}^{s-1} \gamma_{t-i} + \omega_t, \quad \omega_t \sim N(0, \sigma_\gamma)

    Where :math:`s` is the ``season_length`` parameter and :math:`\omega_t` is the (optional) stochastic innovation.
    To give interpretation to the :math:`\gamma` terms, it is helpful to work through the algebra for a simple
    example. Let :math:`s=4`, and omit the shock term. Define initial conditions :math:`\gamma_0, \gamma_{-1},
    \gamma_{-2}`. The value of the seasonal component for the first 5 timesteps will be:

    .. math::
        \begin{align}
            \gamma_1 &= -\gamma_0 - \gamma_{-1} - \gamma_{-2} \\
            \gamma_2 &= -\gamma_1 - \gamma_0 - \gamma_{-1} \\
                     &= -(-\gamma_0 - \gamma_{-1} - \gamma_{-2}) - \gamma_0 - \gamma_{-1}  \\
                     &= (\gamma_0 - \gamma_0 )+ (\gamma_{-1} - \gamma_{-1}) + \gamma_{-2} \\
                     &= \gamma_{-2} \\
            \gamma_3 &= -\gamma_2 - \gamma_1 - \gamma_0  \\
                     &= -\gamma_{-2} - (-\gamma_0 - \gamma_{-1} - \gamma_{-2}) - \gamma_0 \\
                     &=  (\gamma_{-2} - \gamma_{-2}) + \gamma_{-1} + (\gamma_0 - \gamma_0) \\
                     &= \gamma_{-1} \\
            \gamma_4 &= -\gamma_3 - \gamma_2 - \gamma_1 \\
                     &= -\gamma_{-1} - \gamma_{-2} -(-\gamma_0 - \gamma_{-1} - \gamma_{-2}) \\
                     &= (\gamma_{-2} - \gamma_{-2}) + (\gamma_{-1} - \gamma_{-1}) + \gamma_0 \\
                     &= \gamma_0 \\
            \gamma_5 &= -\gamma_4 - \gamma_3 - \gamma_2 \\
                     &= -\gamma_0 - \gamma_{-1} - \gamma_{-2} \\
                     &= \gamma_1
        \end{align}

    This exercise shows that, given a list ``initial_conditions`` of length ``s-1``, the effects of this model will be:

        - Period 1: ``-sum(initial_conditions)``
        - Period 2: ``initial_conditions[-1]``
        - Period 3: ``initial_conditions[-2]``
        - ...
        - Period s: ``initial_conditions[0]``
        - Period s+1: ``-sum(initial_conditions)``

    And so on. So for interpretation, the ``season_length - 1`` initial states are, when reversed, the coefficients
    associated with ``state_names[1:]``.

    .. warning::
        Although the ``state_names`` argument expects a list of length ``season_length``, only ``state_names[1:]``
        will be saved as model dimensions (when ``pop_state=True``), since the 1st coefficient is not identified (it
        is defined as :math:`-\sum_{i=1}^{s-1} \gamma_{t-i}`).

    Examples
    --------
    Estimate a model of monthly data with a gaussian random walk trend and annual seasonality:

    .. code:: python

        from pymc_extras.statespace import structural as st
        import pymc as pm
        import pytensor.tensor as pt
        import pandas as pd

        # Get month names
        state_names = pd.date_range('1900-01-01', '1900-12-31', freq='MS').month_name().tolist()

        # Build the structural model
        grw = st.LevelTrendComponent(order=1, innovations_order=1)
        annual_season = st.TimeSeasonality(season_length=12, name='annual', state_names=state_names, innovations=False)
        ss_mod = (grw + annual_season).build()

        # Estimate with PyMC
        with pm.Model(coords=ss_mod.coords) as model:
            P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])
            initial_trend = pm.Deterministic('initial_trend', pt.zeros(1), dims=ss_mod.param_dims['initial_trend'])
            annual_coefs = pm.Normal('annual_coefs', sigma=1e-2, dims=ss_mod.param_dims['annual_coefs'])
            sigma_trend = pm.HalfNormal('sigma_trend', sigma=1e-6, dims=ss_mod.param_dims['sigma_trend'])
            ss_mod.build_statespace_graph(data, mode='JAX')
            idata = pm.sample(nuts_sampler='numpyro')

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    """

    def __init__(
        self,
        season_length: int,
        innovations: bool = True,
        name: str | None = None,
        state_names: list | None = None,
        pop_state: bool = True,
    ):
        if season_length < 2:
            raise ValueError(f"season_length must be at least 2, got {season_length}")
        if name is None:
            name = f"Seasonal[s={season_length}]"
        if state_names is None:
            state_names = [f"{name}_{i}" for i in range(season_length)]
        else:
            if len(state_names) != season_length:
                raise ValueError(
                    f"state_names must be a list of length season_length, got {len(state_names)}"
                )
            # Work on a copy so the caller's list is not mutated by the pop below.
            state_names = state_names.copy()
        self.innovations = innovations
        self.pop_state = pop_state

        if self.pop_state:
            # In traditional models, the first state isn't identified, so we can help out the user by automatically
            # discarding it.
            # TODO: Can this be stashed and reconstructed automatically somehow?
            state_names.pop(0)

        # Always defined, and always equal to len(state_names), whichever value pop_state takes.
        k_states = season_length - int(self.pop_state)

        super().__init__(
            name=name,
            k_endog=1,
            k_states=k_states,
            k_posdef=int(innovations),
            state_names=state_names,
            measurement_error=False,
            combine_hidden_states=True,
            obs_state_idxs=np.r_[[1.0], np.zeros(k_states - 1)],
        )

    def populate_component_properties(self):
        """Declare ``{name}_coefs`` (one entry per state) and, with innovations, the scalar ``sigma_{name}``."""
        self.param_names = [f"{self.name}_coefs"]
        self.param_info = {
            f"{self.name}_coefs": {
                "shape": (self.k_states,),
                "constraints": None,
                "dims": (f"{self.name}_state",),
            }
        }
        self.param_dims = {f"{self.name}_coefs": (f"{self.name}_state",)}
        self.coords = {f"{self.name}_state": self.state_names}

        if self.innovations:
            self.param_names += [f"sigma_{self.name}"]
            self.param_info[f"sigma_{self.name}"] = {
                "shape": (),
                "constraints": "Positive",
                "dims": None,
            }
            self.shock_names = [f"{self.name}"]

    def make_symbolic_graph(self) -> None:
        """
        Write the seasonal matrices into ``self.ssm``.

        The first state is observed, and the initial state is the registered ``{name}_coefs`` variable. With
        innovations, a single shock with variance ``sigma_{name}**2`` enters the first state.
        """
        if self.pop_state:
            # Reduced form: the new first state is minus the sum of the s - 1 stored states, and the remaining
            # states shift down by one.
            T = np.eye(self.k_states, k=-1)
            T[0, :] = -1
        else:
            # Full form: every seasonal effect is stored, so the transition only rotates the state vector and
            # the observed first state visits the coefficients in order.
            T = np.eye(self.k_states, k=1)
            T[-1, 0] = 1

        self.ssm["transition", :, :] = T
        self.ssm["design", 0, 0] = 1

        initial_states = self.make_and_register_variable(
            f"{self.name}_coefs", shape=(self.k_states,)
        )
        self.ssm["initial_state", np.arange(self.k_states, dtype=int)] = initial_states

        if self.innovations:
            self.ssm["selection", 0, 0] = 1
            season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=())
            cov_idx = ("state_cov", *np.diag_indices(1))
            self.ssm[cov_idx] = season_sigma**2
1237
+
1238
+
1239
class FrequencySeasonality(Component):
    r"""
    Seasonal component, modeled in the frequency domain

    Parameters
    ----------
    season_length: float
        The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
        daily data with weekly seasonal pattern, etc. Non-integer season_length is also permitted, for example
        365.2422 days in a (solar) year.

    n: int
        Number of fourier features to include in the seasonal component. Default is ``season_length // 2``, which
        is the maximum possible. A smaller number can be used for a more wave-like seasonal pattern.

    name: str, default None
        A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
        components are included in the same model. Default is ``f"Frequency[s={season_length}, n={n}]"``

    innovations: bool, default True
        Whether to include stochastic innovations in the strength of the seasonal effect

    Notes
    -----
    A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
    model seasonal effects, the implementation used here is the one described by [1] as the "canonical" frequency domain
    representation. The seasonal component can be expressed:

    .. math::
        \begin{align}
            \gamma_t &= \sum_{j=1}^{n} \gamma_{j,t} \\
            \gamma_{j, t+1} &= \gamma_{j,t} \cos \lambda_j + \gamma_{j,t}^\star \sin \lambda_j + \omega_{j, t} \\
            \gamma_{j, t+1}^\star &= -\gamma_{j,t} \sin \lambda_j + \gamma_{j,t}^\star \cos \lambda_j + \omega_{j,t}^\star \\
            \lambda_j &= \frac{2\pi j}{s}
        \end{align}

    Where :math:`s` is the ``season_length``.

    Unlike a ``TimeSeasonality`` component, a ``FrequencySeasonality`` component does not require integer season
    length. In addition, for long seasonal periods, it is possible to obtain a more compact state space representation
    by choosing ``n << s // 2``. Using ``TimeSeasonality``, an annual seasonal pattern in daily data requires 364
    states, whereas ``FrequencySeasonality`` always requires ``2 * n`` states, regardless of the ``season_length``.
    The price of this compactness is less representational power. At ``n = 1``, the seasonal pattern will be a pure
    sine wave. At ``n = s // 2``, any arbitrary pattern can be represented.

    One cost of the added flexibility of ``FrequencySeasonality`` is reduced interpretability. States of this model are
    coefficients :math:`\gamma_1, \gamma^\star_1, \gamma_2, \gamma_2^\star ..., \gamma_n, \gamma^\star_n` associated
    with different frequencies in the fourier representation of the seasonal pattern. As a result, it is not possible
    to isolate and identify a "Monday" effect, for instance.

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    """

    def __init__(self, season_length, n=None, name=None, innovations=True):
        if n is None:
            n = int(season_length // 2)
        if n < 1:
            # Guards the division below and rules out a component with no states.
            raise ValueError(
                f"n must be at least 1, got n={n} (season_length={season_length}). "
                f"season_length must be at least 2 when n is not given."
            )
        if name is None:
            name = f"Frequency[s={season_length}, n={n}]"

        k_states = n * 2
        self.n = n
        self.season_length = season_length
        self.innovations = innovations

        # If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
        # get a parameter assigned to it and should just be fixed to zero.
        # Test this way (rather than n == s // 2) to catch cases when season_length is non-integer.
        self.last_state_not_identified = self.season_length / self.n == 2.0
        self.n_coefs = k_states - int(self.last_state_not_identified)

        # Every other state (the "Cos" states) is observed.
        obs_state_idx = np.zeros(k_states)
        obs_state_idx[slice(0, k_states, 2)] = 1

        super().__init__(
            name=name,
            k_endog=1,
            k_states=k_states,
            k_posdef=k_states * int(self.innovations),
            measurement_error=False,
            combine_hidden_states=True,
            obs_state_idxs=obs_state_idx,
        )

    def make_symbolic_graph(self) -> None:
        """
        Write the seasonal matrices into ``self.ssm``.

        The first ``n_coefs`` initial states come from the registered ``{name}`` variable, and the transition
        matrix is block diagonal with one frequency block per harmonic. With innovations, all states share the
        shock variance ``sigma_{name}**2``.
        """
        self.ssm["design", 0, slice(0, self.k_states, 2)] = 1

        init_state = self.make_and_register_variable(f"{self.name}", shape=(self.n_coefs,))

        # When the last state is not identified it is left out here and no parameter is assigned to it.
        init_state_idx = np.arange(self.n_coefs, dtype=int)
        self.ssm["initial_state", init_state_idx] = init_state

        T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
        T = pt.linalg.block_diag(*T_mats)
        self.ssm["transition", :, :] = T

        if self.innovations:
            sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=())
            self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season**2
            self.ssm["selection", :, :] = np.eye(self.k_states)

    def populate_component_properties(self):
        """Declare state names, the ``{name}`` coefficient vector and, with innovations, ``sigma_{name}``."""
        self.state_names = [f"{self.name}_{f}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]]
        self.param_names = [f"{self.name}"]

        self.param_dims = {self.name: (f"{self.name}_state",)}
        self.param_info = {
            f"{self.name}": {
                "shape": (self.n_coefs,),
                "constraints": None,
                "dims": (f"{self.name}_state",),
            }
        }

        # Coordinates cover only the states that get a parameter (all but the last one when it is unidentified).
        self.coords = {f"{self.name}_state": self.state_names[: self.n_coefs]}

        if self.innovations:
            self.shock_names = self.state_names.copy()
            self.param_names += [f"sigma_{self.name}"]
            self.param_info[f"sigma_{self.name}"] = {
                "shape": (),
                "constraints": "Positive",
                "dims": None,
            }
1363
+
1364
+
1365
class CycleComponent(Component):
    r"""
    A component for modeling longer-term cyclical effects

    Parameters
    ----------
    name: str
        Name of the component. Used in generated coordinates and state names. If None, a descriptive name will be
        used.

    cycle_length: int or float, optional
        The length of the cycle, in the calendar units of your data. For example, if your data is monthly, and you
        want to model a 12-month cycle, use ``cycle_length=12``. You cannot specify both ``cycle_length`` and
        ``estimate_cycle_length``.

    estimate_cycle_length: bool, default False
        Whether to estimate the cycle length. If True, an additional parameter, ``{name}_length`` will be added to
        the model. You cannot specify both ``cycle_length`` and ``estimate_cycle_length``.

    dampen: bool, default False
        Whether to dampen the cycle by multiplying by a dampening factor :math:`\rho` at every timestep. If true,
        an additional parameter, ``{name}_dampening_factor`` will be added to the model.

    innovations: bool, default True
        Whether to include stochastic innovations in the strength of the seasonal effect. If True, an additional
        parameter, ``sigma_{name}`` will be added to the model.

    Notes
    -----
    The cycle component is very similar in implementation to the frequency domain seasonal component, except that it
    is restricted to n=1. The cycle component can be expressed:

    .. math::
        \begin{align}
            \gamma_t &= \rho \gamma_{t-1} \cos \lambda + \rho \gamma_{t-1}^\star \sin \lambda + \omega_{t} \\
            \gamma_{t}^\star &= -\rho \gamma_{t-1} \sin \lambda + \rho \gamma_{t-1}^\star \cos \lambda + \omega_{t}^\star \\
            \lambda &= \frac{2\pi}{s}
        \end{align}

    Where :math:`s` is the ``cycle_length``. [1] recommend that this component be used for longer term cyclical
    effects, such as business cycles, and that the seasonal component be used for shorter term effects, such as
    weekly or monthly seasonality.

    Unlike a FrequencySeasonality component, the length of a CycleComponent can be estimated.

    Examples
    --------
    Estimate a business cycle with length between 6 and 12 years:

    .. code:: python

        from pymc_extras.statespace import structural as st
        import pymc as pm
        import pytensor.tensor as pt
        import pandas as pd
        import numpy as np

        data = np.random.normal(size=(100, 1))

        # Build the structural model
        grw = st.LevelTrendComponent(order=1, innovations_order=1)
        cycle = st.CycleComponent('business_cycle', estimate_cycle_length=True, dampen=False)
        ss_mod = (grw + cycle).build()

        # Estimate with PyMC
        with pm.Model(coords=ss_mod.coords) as model:
            P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states), dims=ss_mod.param_dims['P0'])
            initial_trend = pm.Normal('initial_trend', dims=ss_mod.param_dims['initial_trend'])
            sigma_trend = pm.HalfNormal('sigma_trend', dims=ss_mod.param_dims['sigma_trend'])

            # The initial cycle state has two entries (Cos and Sin), so it needs its state dimension.
            cycle_strength = pm.Normal('business_cycle', dims=ss_mod.param_dims['business_cycle'])
            cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)

            sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
            ss_mod.build_statespace_graph(data, mode='JAX')

            idata = pm.sample(nuts_sampler='numpyro')

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    """

    def __init__(
        self,
        name: str | None = None,
        cycle_length: float | None = None,
        estimate_cycle_length: bool = False,
        dampen: bool = False,
        innovations: bool = True,
    ):
        if cycle_length is None and not estimate_cycle_length:
            raise ValueError("Must specify cycle_length if estimate_cycle_length is False")
        if cycle_length is not None and estimate_cycle_length:
            raise ValueError("Cannot specify cycle_length if estimate_cycle_length is True")
        if name is None:
            cycle = int(cycle_length) if cycle_length is not None else "Estimate"
            name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]"

        self.estimate_cycle_length = estimate_cycle_length
        self.cycle_length = cycle_length
        self.innovations = innovations
        self.dampen = dampen
        self.n_coefs = 1

        k_states = 2
        k_endog = 1
        k_posdef = 2

        # Only the first ("Cos") state is observed.
        obs_state_idx = np.zeros(k_states)
        obs_state_idx[slice(0, k_states, 2)] = 1

        super().__init__(
            name=name,
            k_endog=k_endog,
            k_states=k_states,
            k_posdef=k_posdef,
            measurement_error=False,
            combine_hidden_states=True,
            obs_state_idxs=obs_state_idx,
        )

    def make_symbolic_graph(self) -> None:
        """
        Write the cycle matrices into ``self.ssm``.

        The cycle length is either the fixed ``cycle_length`` or the registered ``{name}_length`` variable, and
        the transition block is scaled by the dampening factor when ``dampen`` is set. With innovations, both
        states share the shock variance ``sigma_{name}**2``.
        """
        self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
        self.ssm["selection", :, :] = np.eye(self.k_states)
        self.param_dims = {self.name: (f"{self.name}_state",)}
        self.coords = {f"{self.name}_state": self.state_names}

        init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,))

        self.ssm["initial_state", :] = init_state

        # This is the cycle length s, passed as the first argument of the transition-block helper.
        if self.estimate_cycle_length:
            cycle_length = self.make_and_register_variable(f"{self.name}_length", shape=())
        else:
            cycle_length = self.cycle_length

        if self.dampen:
            rho = self.make_and_register_variable(f"{self.name}_dampening_factor", shape=())
        else:
            rho = 1

        T = rho * _frequency_transition_block(cycle_length, j=1)
        self.ssm["transition", :, :] = T

        if self.innovations:
            sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
            self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2

    def populate_component_properties(self):
        """Declare state names and metadata for every parameter enabled by the constructor flags."""
        self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
        self.param_names = [f"{self.name}"]

        self.param_info = {
            f"{self.name}": {
                "shape": (2,),
                "constraints": None,
                "dims": (f"{self.name}_state",),
            }
        }

        if self.estimate_cycle_length:
            self.param_names += [f"{self.name}_length"]
            self.param_info[f"{self.name}_length"] = {
                "shape": (),
                "constraints": "Positive, non-zero",
                "dims": None,
            }

        if self.dampen:
            self.param_names += [f"{self.name}_dampening_factor"]
            self.param_info[f"{self.name}_dampening_factor"] = {
                "shape": (),
                "constraints": "0 < x ≤ 1",
                "dims": None,
            }

        if self.innovations:
            self.param_names += [f"sigma_{self.name}"]
            self.param_info[f"sigma_{self.name}"] = {
                "shape": (),
                "constraints": "Positive",
                "dims": None,
            }
            self.shock_names = self.state_names.copy()
1552
+
1553
+
1554
+ class RegressionComponent(Component):
1555
def __init__(
    self,
    k_exog: int | None = None,
    name: str | None = "Exogenous",
    state_names: list[str] | None = None,
    innovations=False,
):
    """
    Set up a regression component with one state per exogenous regressor.

    Parameters
    ----------
    k_exog: int, optional
        Number of regressors. May be omitted when ``state_names`` is given.
    name: str, default "Exogenous"
        Name of the component. The exogenous data is registered under ``f"data_{name}"``.
    state_names: list of str, optional
        One label per regressor. Generated from ``name`` when omitted.
    innovations: bool, default False
        Stored on the instance as ``self.innovations``.
    """
    self.innovations = innovations

    # Validates the arguments and stores the resolved labels on self.state_names.
    n_regressors = self._handle_input_data(k_exog, state_names, name)

    super().__init__(
        name=name,
        k_endog=1,
        k_states=n_regressors,
        k_posdef=n_regressors,
        state_names=self.state_names,
        measurement_error=False,
        combine_hidden_states=False,
        exog_names=[f"data_{name}"],
        obs_state_idxs=np.ones(n_regressors),
    )
1580
+
1581
+ @staticmethod
1582
+ def _get_state_names(k_exog: int | None, state_names: list[str] | None, name: str):
1583
+ if k_exog is None and state_names is None:
1584
+ raise ValueError("Must specify at least one of k_exog or state_names")
1585
+ if state_names is not None and k_exog is not None:
1586
+ if len(state_names) != k_exog:
1587
+ raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}")
1588
+ elif k_exog is None:
1589
+ k_exog = len(state_names)
1590
+ else:
1591
+ state_names = [f"{name}_{i + 1}" for i in range(k_exog)]
1592
+
1593
+ return k_exog, state_names
1594
+
1595
+ def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -> int:
1596
+ k_exog, state_names = self._get_state_names(k_exog, state_names, name)
1597
+ self.state_names = state_names
1598
+
1599
+ return k_exog
1600
+
1601
+ def make_symbolic_graph(self) -> None:
1602
+ betas = self.make_and_register_variable(f"beta_{self.name}", shape=(self.k_states,))
1603
+ regression_data = self.make_and_register_data(
1604
+ f"data_{self.name}", shape=(None, self.k_states)
1605
+ )
1606
+
1607
+ self.ssm["initial_state", :] = betas
1608
+ self.ssm["transition", :, :] = np.eye(self.k_states)
1609
+ self.ssm["selection", :, :] = np.eye(self.k_states)
1610
+ self.ssm["design"] = pt.expand_dims(regression_data, 1)
1611
+
1612
+ if self.innovations:
1613
+ sigma_beta = self.make_and_register_variable(
1614
+ f"sigma_beta_{self.name}", (self.k_states,)
1615
+ )
1616
+ row_idx, col_idx = np.diag_indices(self.k_states)
1617
+ self.ssm["state_cov", row_idx, col_idx] = sigma_beta**2
1618
+
1619
+ def populate_component_properties(self) -> None:
1620
+ self.shock_names = self.state_names
1621
+
1622
+ self.param_names = [f"beta_{self.name}"]
1623
+ self.data_names = [f"data_{self.name}"]
1624
+ self.param_dims = {
1625
+ f"beta_{self.name}": ("exog_state",),
1626
+ }
1627
+
1628
+ self.param_info = {
1629
+ f"beta_{self.name}": {
1630
+ "shape": (self.k_states,),
1631
+ "constraints": None,
1632
+ "dims": ("exog_state",),
1633
+ },
1634
+ }
1635
+
1636
+ self.data_info = {
1637
+ f"data_{self.name}": {
1638
+ "shape": (None, self.k_states),
1639
+ "dims": (TIME_DIM, "exog_state"),
1640
+ },
1641
+ }
1642
+ self.coords = {"exog_state": self.state_names}
1643
+
1644
+ if self.innovations:
1645
+ self.param_names += [f"sigma_beta_{self.name}"]
1646
+ self.param_dims[f"sigma_beta_{self.name}"] = "exog_state"
1647
+ self.param_info[f"sigma_beta_{self.name}"] = {
1648
+ "shape": (),
1649
+ "constraints": "Positive",
1650
+ "dims": ("exog_state",),
1651
+ }