pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. pymc_extras/distributions/timeseries.py +10 -10
  2. pymc_extras/inference/dadvi/dadvi.py +14 -83
  3. pymc_extras/inference/laplace_approx/laplace.py +187 -159
  4. pymc_extras/inference/pathfinder/pathfinder.py +12 -7
  5. pymc_extras/inference/smc/sampling.py +2 -2
  6. pymc_extras/model/marginal/distributions.py +4 -2
  7. pymc_extras/model/marginal/marginal_model.py +12 -2
  8. pymc_extras/prior.py +3 -3
  9. pymc_extras/statespace/core/properties.py +276 -0
  10. pymc_extras/statespace/core/statespace.py +182 -45
  11. pymc_extras/statespace/filters/distributions.py +19 -34
  12. pymc_extras/statespace/filters/kalman_filter.py +13 -12
  13. pymc_extras/statespace/filters/kalman_smoother.py +2 -2
  14. pymc_extras/statespace/models/DFM.py +179 -168
  15. pymc_extras/statespace/models/ETS.py +177 -151
  16. pymc_extras/statespace/models/SARIMAX.py +149 -152
  17. pymc_extras/statespace/models/VARMAX.py +134 -145
  18. pymc_extras/statespace/models/__init__.py +8 -1
  19. pymc_extras/statespace/models/structural/__init__.py +30 -8
  20. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  21. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  22. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  23. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  24. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  25. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  26. pymc_extras/statespace/models/structural/core.py +397 -286
  27. pymc_extras/statespace/models/utilities.py +5 -20
  28. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
  29. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
  30. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  31. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,6 @@
1
1
  import functools as ft
2
2
  import logging
3
3
 
4
- from collections.abc import Sequence
5
4
  from itertools import pairwise
6
5
  from typing import Any
7
6
 
@@ -11,12 +10,28 @@ import xarray as xr
11
10
  from pytensor import Mode, Variable, config
12
11
  from pytensor import tensor as pt
13
12
 
14
- from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation
13
+ from pymc_extras.statespace.core.properties import (
14
+ Coord,
15
+ CoordInfo,
16
+ Data,
17
+ DataInfo,
18
+ Parameter,
19
+ ParameterInfo,
20
+ Shock,
21
+ ShockInfo,
22
+ State,
23
+ StateInfo,
24
+ SymbolicData,
25
+ SymbolicDataInfo,
26
+ SymbolicVariable,
27
+ SymbolicVariableInfo,
28
+ )
29
+ from pymc_extras.statespace.core.representation import PytensorRepresentation
30
+ from pymc_extras.statespace.core.statespace import PyMCStateSpace, _validate_property
15
31
  from pymc_extras.statespace.models.utilities import (
16
32
  add_tensors_by_dim_labels,
17
33
  conform_time_varying_and_time_invariant_matrices,
18
34
  join_tensors_by_dim_labels,
19
- make_default_coords,
20
35
  )
21
36
  from pymc_extras.statespace.utils.constants import (
22
37
  ALL_STATE_AUX_DIM,
@@ -32,54 +47,13 @@ class StructuralTimeSeries(PyMCStateSpace):
32
47
  r"""
33
48
  Structural Time Series Model
34
49
 
35
- The structural time series model, named by [1] and presented in statespace form in [2], is a framework for
36
- decomposing a univariate time series into level, trend, seasonal, and cycle components. It also admits the
37
- possibility of exogenous regressors. Unlike the SARIMAX framework, the time series is not assumed to be stationary.
50
+ A framework for decomposing a univariate time series into level, trend, seasonal, and cycle
51
+ components, as named by [1]_ and presented in state space form in [2]_.
38
52
 
39
- Parameters
40
- ----------
41
- ssm : PytensorRepresentation
42
- The state space representation containing system matrices.
43
- name : str
44
- Name of the model. If None, defaults to "StructuralTimeSeries".
45
- state_names : list[str]
46
- Names of the hidden states in the model.
47
- observed_state_names : list[str]
48
- Names of the observed variables.
49
- data_names : list[str]
50
- Names of data variables expected by the model.
51
- shock_names : list[str]
52
- Names of innovation/shock processes.
53
- param_names : list[str]
54
- Names of model parameters.
55
- exog_names : list[str]
56
- Names of exogenous variables.
57
- param_dims : dict[str, tuple[int]]
58
- Dimension specifications for parameters.
59
- coords : dict[str, Sequence]
60
- Coordinate specifications for the model.
61
- param_info : dict[str, dict[str, Any]]
62
- Information about parameters including shapes and constraints.
63
- data_info : dict[str, dict[str, Any]]
64
- Information about data variables.
65
- component_info : dict[str, dict[str, Any]]
66
- Information about model components.
67
- measurement_error : bool
68
- Whether the model includes measurement error.
69
- name_to_variable : dict[str, Variable]
70
- Mapping from parameter names to PyTensor variables.
71
- name_to_data : dict[str, Variable] | None, optional
72
- Mapping from data names to PyTensor variables. Default is None.
73
- verbose : bool, optional
74
- Whether to print model information. Default is True.
75
- filter_type : str, optional
76
- Type of Kalman filter to use. Default is "standard".
77
- mode : str | Mode | None, optional
78
- PyTensor compilation mode. Default is None.
79
-
80
- Notes
81
- -----
82
- The structural time series model decomposes a time series into interpretable components:
53
+ This class is not typically instantiated directly. Instead, use ``Component.build()`` to
54
+ construct a model from components combined with the ``+`` operator.
55
+
56
+ The model decomposes a time series into interpretable components:
83
57
 
84
58
  .. math::
85
59
 
@@ -94,10 +68,6 @@ class StructuralTimeSeries(PyMCStateSpace):
94
68
  - :math:`\xi_t` is the autoregressive component
95
69
  - :math:`\varepsilon_t` is the measurement error
96
70
 
97
- The model is built by combining individual components (e.g., LevelTrendComponent,
98
- TimeSeasonality, CycleComponent) using the addition operator. Each component
99
- contributes to the overall state space representation.
100
-
101
71
  Examples
102
72
  --------
103
73
  Create a model with trend and seasonal components:
@@ -108,7 +78,7 @@ class StructuralTimeSeries(PyMCStateSpace):
108
78
  import pymc as pm
109
79
  import pytensor.tensor as pt
110
80
 
111
- trend = st.LevelTrendComponent(order=2 innovations_order=1)
81
+ trend = st.LevelTrend(order=2, innovations_order=1)
112
82
  seasonal = st.TimeSeasonality(season_length=12, innovations=True)
113
83
  error = st.MeasurementError()
114
84
 
@@ -128,6 +98,14 @@ class StructuralTimeSeries(PyMCStateSpace):
128
98
  ss_mod.build_statespace_graph(data)
129
99
  idata = pm.sample()
130
100
 
101
+ See Also
102
+ --------
103
+ Component : Base class for structural time series components.
104
+ LevelTrend : Component for modeling level and trend.
105
+ TimeSeasonality : Component for seasonal effects.
106
+ Cycle : Component for cyclical effects.
107
+ Autoregressive : Component for autoregressive dynamics.
108
+
131
109
  References
132
110
  ----------
133
111
  .. [1] Harvey, A. C. (1989). Forecasting, structural time series models and the
@@ -140,49 +118,63 @@ class StructuralTimeSeries(PyMCStateSpace):
140
118
  self,
141
119
  ssm: PytensorRepresentation,
142
120
  name: str,
143
- state_names: list[str],
144
- observed_state_names: list[str],
145
- data_names: list[str],
146
- shock_names: list[str],
147
- param_names: list[str],
148
- exog_names: list[str],
149
- param_dims: dict[str, tuple[int]],
150
- coords: dict[str, Sequence],
151
- param_info: dict[str, dict[str, Any]],
152
- data_info: dict[str, dict[str, Any]],
121
+ coords_info: CoordInfo,
122
+ param_info: ParameterInfo,
123
+ data_info: DataInfo,
124
+ shock_info: ShockInfo,
125
+ state_info: StateInfo,
126
+ tensor_variable_info: SymbolicVariableInfo,
127
+ tensor_data_info: SymbolicDataInfo,
153
128
  component_info: dict[str, dict[str, Any]],
154
129
  measurement_error: bool,
155
- name_to_variable: dict[str, Variable],
156
- name_to_data: dict[str, Variable] | None = None,
157
130
  verbose: bool = True,
158
131
  filter_type: str = "standard",
159
132
  mode: str | Mode | None = None,
160
133
  ):
161
- name = "StructuralTimeSeries" if name is None else name
162
-
163
- self._name = name
164
- self._observed_state_names = observed_state_names
134
+ """
135
+ Initialize a StructuralTimeSeries model.
165
136
 
166
- k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog
167
- param_names, param_dims, param_info = self._add_inital_state_cov_to_properties(
168
- param_names, param_dims, param_info, k_states
169
- )
137
+ This constructor is typically called by ``Component.build()`` rather than directly.
170
138
 
171
- self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog)
172
- self._data_names = self._strip_data_names_if_unambiguous(data_names, k_endog)
173
- self._shock_names = self._strip_data_names_if_unambiguous(shock_names, k_endog)
174
- self._param_names = self._strip_data_names_if_unambiguous(param_names, k_endog)
175
- self._param_dims = param_dims
139
+ Parameters
140
+ ----------
141
+ ssm : PytensorRepresentation
142
+ The state space representation containing system matrices.
143
+ name : str
144
+ Name of the model. If None or empty, defaults to "StructuralTimeSeries".
145
+ coords_info : CoordInfo
146
+ Coordinate specifications for model dimensions.
147
+ param_info : ParameterInfo
148
+ Information about model parameters including shapes and constraints.
149
+ data_info : DataInfo
150
+ Information about data variables expected by the model.
151
+ shock_info : ShockInfo
152
+ Information about innovation/shock processes.
153
+ state_info : StateInfo
154
+ Information about hidden and observed states.
155
+ tensor_variable_info : SymbolicVariableInfo
156
+ Mapping from parameter names to PyTensor symbolic variables.
157
+ tensor_data_info : SymbolicDataInfo
158
+ Mapping from data names to PyTensor symbolic variables.
159
+ component_info : dict[str, dict[str, Any]]
160
+ Information about model components used for state extraction.
161
+ measurement_error : bool
162
+ Whether the model includes measurement error.
163
+ verbose : bool, default True
164
+ Whether to print model information during construction.
165
+ filter_type : str, default "standard"
166
+ Type of Kalman filter to use.
167
+ mode : str | Mode | None, default None
168
+ PyTensor compilation mode.
169
+ """
170
+ self._name = name or "StructuralTimeSeries"
171
+ self.measurement_error = measurement_error
176
172
 
177
- default_coords = make_default_coords(self)
178
- coords.update(default_coords)
173
+ k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog
179
174
 
180
- self._coords = {
181
- k: self._strip_data_names_if_unambiguous(v, k_endog) for k, v in coords.items()
182
- }
183
- self._param_info = param_info.copy()
184
- self._data_info = data_info.copy()
185
- self.measurement_error = measurement_error
175
+ self._init_info_objects(
176
+ param_info, data_info, shock_info, state_info, coords_info, k_states, k_endog
177
+ )
186
178
 
187
179
  super().__init__(
188
180
  k_endog,
@@ -193,30 +185,75 @@ class StructuralTimeSeries(PyMCStateSpace):
193
185
  measurement_error=measurement_error,
194
186
  mode=mode,
195
187
  )
188
+
189
+ self._tensor_variable_info = tensor_variable_info
190
+ self._tensor_data_info = tensor_data_info
191
+ self._component_info = component_info.copy()
192
+ self._exog_names = data_info.exogenous_names
193
+ self._needs_exog_data = data_info.needs_exogenous_data
194
+
195
+ self._init_ssm(ssm, k_posdef)
196
+
197
+ def _init_info_objects(
198
+ self,
199
+ param_info: ParameterInfo,
200
+ data_info: DataInfo,
201
+ shock_info: ShockInfo,
202
+ state_info: StateInfo,
203
+ coords_info: CoordInfo,
204
+ k_states: int,
205
+ k_endog: int,
206
+ ) -> None:
207
+ """Initialize all info objects and set observed state names."""
208
+ self._observed_state_names = state_info.observed_state_names
209
+
210
+ param_names, param_dims, param_info = self._add_inital_state_cov_to_properties(
211
+ param_info, k_states
212
+ )
213
+ self._param_dims = param_dims
214
+
215
+ self._param_info = param_info
216
+ self._data_info = data_info
217
+ self._shock_info = shock_info
218
+ self._state_info = state_info
219
+
220
+ # Stripped names must be set before default_coords_from_model (which accesses state_names)
221
+ self._init_stripped_names(k_endog)
222
+
223
+ default_coords = coords_info.default_coords_from_model(self)
224
+ self._coords_info = coords_info.merge(default_coords)
225
+
226
+ def _init_stripped_names(self, k_endog: int) -> None:
227
+ """Strip data suffixes from names when k_endog == 1 for cleaner output."""
228
+
229
+ def strip(names):
230
+ return self._strip_data_names_if_unambiguous(names, k_endog)
231
+
232
+ self._state_names = strip(self._state_info.unobserved_state_names)
233
+ self._data_names = strip([d.name for d in self._data_info if not d.is_exogenous])
234
+ self._shock_names = strip(self._shock_info.names)
235
+ self._param_names = strip(self._param_info.names)
236
+
237
+ def _init_ssm(self, ssm: PytensorRepresentation, k_posdef: int) -> None:
238
+ """Initialize state space model representation."""
196
239
  self.ssm = ssm.copy()
197
240
 
198
241
  if k_posdef == 0:
199
- # If there is no randomness in the model, add dummy matrices to the representation to avoid errors
200
- # when we go to construct random variables from the matrices
201
242
  self.ssm.k_posdef = self.k_posdef
202
243
  self.ssm.shapes["state_cov"] = (1, 1, 1)
203
244
  self.ssm["state_cov"] = pt.zeros((1, 1, 1))
204
-
205
245
  self.ssm.shapes["selection"] = (1, self.k_states, 1)
206
246
  self.ssm["selection"] = pt.zeros((1, self.k_states, 1))
207
247
 
208
- self._component_info = component_info.copy()
209
-
210
- self._name_to_variable = name_to_variable.copy()
211
- self._name_to_data = name_to_data.copy()
212
-
213
- self._exog_names = exog_names.copy()
214
- self._needs_exog_data = len(exog_names) > 0
215
-
216
248
  P0 = self.make_and_register_variable("P0", shape=(self.k_states, self.k_states))
217
249
  self.ssm["initial_state_cov"] = P0
218
250
 
219
- def _strip_data_names_if_unambiguous(self, names: list[str], k_endog: int):
251
+ def _populate_properties(self) -> None:
252
+ # The base class method needs to be overridden because we directly set properties in
253
+ # the __init__ method.
254
+ pass
255
+
256
+ def _strip_data_names_if_unambiguous(self, names: list[str] | tuple[str, ...], k_endog: int):
220
257
  """
221
258
  State names from components should always be of the form name[data_name], in the case that the component is
222
259
  associated with multiple observed states. Not doing so leads to ambiguity -- we might have two level states,
@@ -226,62 +263,40 @@ class StructuralTimeSeries(PyMCStateSpace):
226
263
  the state name. This is a bit cleaner.
227
264
  """
228
265
  if k_endog == 1:
229
- [data_name] = self.observed_states
230
- return [
266
+ [data_name] = self._observed_state_names
267
+ return tuple(
231
268
  name.replace(f"[{data_name}]", "") if isinstance(name, str) else name
232
269
  for name in names
233
- ]
270
+ )
234
271
 
235
272
  else:
236
273
  return names
237
274
 
238
- @staticmethod
239
- def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states):
240
- param_names += ["P0"]
241
- param_dims["P0"] = (ALL_STATE_DIM, ALL_STATE_AUX_DIM)
242
- param_info["P0"] = {
243
- "shape": (k_states, k_states),
244
- "constraints": "Positive semi-definite",
245
- "dims": param_dims["P0"],
246
- }
247
-
248
- return param_names, param_dims, param_info
249
-
250
275
  @property
251
- def param_names(self):
252
- return self._param_names
253
-
254
- @property
255
- def data_names(self) -> list[str]:
256
- return self._data_names
257
-
258
- @property
259
- def state_names(self):
276
+ def state_names(self) -> tuple[str, ...]:
277
+ """Return stripped state names (without [data_name] suffix when k_endog == 1)."""
260
278
  return self._state_names
261
279
 
262
280
  @property
263
- def observed_states(self):
264
- return self._observed_state_names
265
-
266
- @property
267
- def shock_names(self):
281
+ def shock_names(self) -> tuple[str, ...]:
282
+ """Return stripped shock names (without [data_name] suffix when k_endog == 1)."""
268
283
  return self._shock_names
269
284
 
270
- @property
271
- def param_dims(self):
272
- return self._param_dims
273
-
274
- @property
275
- def coords(self) -> dict[str, Sequence]:
276
- return self._coords
285
+ @staticmethod
286
+ def _add_inital_state_cov_to_properties(param_info, k_states):
287
+ initial_state_cov_param = Parameter(
288
+ name="P0",
289
+ shape=(k_states, k_states),
290
+ dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM),
291
+ constraints="Positive semi-definite",
292
+ )
277
293
 
278
- @property
279
- def param_info(self) -> dict[str, dict[str, Any]]:
280
- return self._param_info
294
+ if param_info is not None:
295
+ param_info = param_info.add(initial_state_cov_param)
296
+ else:
297
+ param_info = ParameterInfo(parameters=(initial_state_cov_param,))
281
298
 
282
- @property
283
- def data_info(self) -> dict[str, dict[str, Any]]:
284
- return self._data_info
299
+ return param_info.names, [p.dims for p in param_info], param_info
285
300
 
286
301
  def make_symbolic_graph(self) -> None:
287
302
  """
@@ -401,7 +416,7 @@ class StructuralTimeSeries(PyMCStateSpace):
401
416
  new_idata.coords.update({state_dim: new_state_names})
402
417
  return new_idata
403
418
 
404
- var_names = list(idata.data_vars.keys())
419
+ var_names: list[str] = list(idata.data_vars.keys()) # type: ignore[arg-type]
405
420
  is_latent = [idata[name].shape[-1] == self.k_states for name in var_names]
406
421
  new_state_names = self._get_subcomponent_names()
407
422
 
@@ -409,7 +424,7 @@ class StructuralTimeSeries(PyMCStateSpace):
409
424
  dropped_vars = set(var_names) - set(latent_names)
410
425
  if len(dropped_vars) > 0:
411
426
  _log.warning(
412
- f"Variables {', '.join(dropped_vars)} do not contain all hidden states (their last dimension "
427
+ f"Variables {', '.join(sorted(dropped_vars))} do not contain all hidden states (their last dimension "
413
428
  f"is not {self.k_states}). They will not be present in the modified idata."
414
429
  )
415
430
  if len(dropped_vars) == len(var_names):
@@ -445,19 +460,12 @@ class Component:
445
460
  k_posdef : int
446
461
  Rank of the state covariance matrix, or the number of sources of innovations
447
462
  in the component model.
448
- state_names : list[str] | None, optional
449
- Names of the hidden states. If None, defaults to empty list.
450
- observed_state_names : list[str] | None, optional
451
- Names of the observed states associated with this component. Must have the same
452
- length as k_endog. If None, defaults to empty list.
453
- data_names : list[str] | None, optional
454
- Names of data variables expected by the component. If None, defaults to empty list.
455
- shock_names : list[str] | None, optional
456
- Names of innovation/shock processes. If None, defaults to empty list.
457
- param_names : list[str] | None, optional
458
- Names of component parameters. If None, defaults to empty list.
459
- exog_names : list[str] | None, optional
460
- Names of exogenous variables. If None, defaults to empty list.
463
+ base_state_names : list[str] | None, optional
464
+ Base names of hidden states, before any transformations by set_states().
465
+ Subclasses typically transform these (e.g., adding suffixes). If None, defaults to empty list.
466
+ base_observed_state_names : list[str] | None, optional
467
+ Base names of observed states, before any transformations by set_states().
468
+ If None, defaults to empty list.
461
469
  representation : PytensorRepresentation | None, optional
462
470
  Pre-existing state space representation. If None, creates a new one.
463
471
  measurement_error : bool, optional
@@ -484,7 +492,7 @@ class Component:
484
492
 
485
493
  from pymc_extras.statespace import structural as st
486
494
 
487
- trend = st.LevelTrendComponent(order=2, innovations_order=1)
495
+ trend = st.LevelTrend(order=2, innovations_order=1)
488
496
  seasonal = st.TimeSeasonality(season_length=12, innovations=True)
489
497
  model = (trend + seasonal).build()
490
498
 
@@ -493,10 +501,10 @@ class Component:
493
501
  See Also
494
502
  --------
495
503
  StructuralTimeSeries : The complete model class that combines components.
496
- LevelTrendComponent : Component for modeling level and trend.
504
+ LevelTrend : Component for modeling level and trend.
497
505
  TimeSeasonality : Component for seasonal effects.
498
- CycleComponent : Component for cyclical effects.
499
- RegressionComponent : Component for regression effects.
506
+ Cycle : Component for cyclical effects.
507
+ Regression : Component for regression effects.
500
508
  """
501
509
 
502
510
  def __init__(
@@ -505,12 +513,8 @@ class Component:
505
513
  k_endog,
506
514
  k_states,
507
515
  k_posdef,
508
- state_names=None,
509
- observed_state_names=None,
510
- data_names=None,
511
- shock_names=None,
512
- param_names=None,
513
- exog_names=None,
516
+ base_state_names=None,
517
+ base_observed_state_names=None,
514
518
  representation: PytensorRepresentation | None = None,
515
519
  measurement_error=False,
516
520
  combine_hidden_states=True,
@@ -519,53 +523,202 @@ class Component:
519
523
  share_states: bool = False,
520
524
  ):
521
525
  self.name = name
522
- self.k_endog = k_endog
523
- self.k_states = k_states
524
526
  self.share_states = share_states
525
- self.k_posdef = k_posdef
526
527
  self.measurement_error = measurement_error
527
528
 
528
- self.state_names = list(state_names) if state_names is not None else []
529
- self.observed_state_names = (
530
- list(observed_state_names) if observed_state_names is not None else []
529
+ base_state_names = list(base_state_names) if base_state_names is not None else []
530
+ base_observed_state_names = (
531
+ list(base_observed_state_names) if base_observed_state_names is not None else []
531
532
  )
532
- self.data_names = list(data_names) if data_names is not None else []
533
- self.shock_names = list(shock_names) if shock_names is not None else []
534
- self.param_names = list(param_names) if param_names is not None else []
535
- self.exog_names = list(exog_names) if exog_names is not None else []
536
533
 
537
- self.needs_exog_data = len(self.exog_names) > 0
538
- self.coords = {}
539
- self.param_dims = {}
534
+ self._k_posdef = k_posdef
535
+ self._k_endog = len(base_observed_state_names) or k_endog
536
+ self._k_states = k_states
537
+ self.base_state_names = base_state_names
538
+ self.base_observed_state_names = base_observed_state_names
540
539
 
541
- self.param_info = {}
542
- self.data_info = {}
540
+ self._init_ssm(representation, k_endog, k_states, k_posdef)
543
541
 
544
- self.param_counts = {}
542
+ self._tensor_variable_info = SymbolicVariableInfo()
543
+ self._tensor_data_info = SymbolicDataInfo()
544
+
545
+ if not component_from_sum:
546
+ self.populate_component_properties()
547
+ self.make_symbolic_graph()
548
+
549
+ self._component_info = {
550
+ self.name: {
551
+ "k_states": k_states,
552
+ "k_endog": k_endog,
553
+ "k_posdef": k_posdef,
554
+ "observed_state_names": self._state_info.observed_state_names,
555
+ "combine_hidden_states": combine_hidden_states,
556
+ "obs_state_idx": obs_state_idxs,
557
+ "share_states": self.share_states,
558
+ }
559
+ }
545
560
 
561
+ def _init_ssm(
562
+ self,
563
+ representation: PytensorRepresentation | None,
564
+ k_endog: int,
565
+ k_states: int,
566
+ k_posdef: int,
567
+ ) -> None:
568
+ """Initialize state space model representation."""
546
569
  if representation is None:
547
570
  self.ssm = PytensorRepresentation(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)
548
571
  else:
549
572
  self.ssm = representation
550
573
 
551
- self._name_to_variable = {}
552
- self._name_to_data = {}
574
+ def populate_component_properties(self) -> None:
575
+ self._set_states()
576
+ self._set_parameters()
577
+ self._set_shocks()
578
+ self._set_data_info()
579
+ self._set_coords()
553
580
 
554
- if not component_from_sum:
555
- self.populate_component_properties()
556
- self.make_symbolic_graph()
581
+ def set_states(self) -> State | tuple[State, ...] | None:
582
+ """
583
+ Set default state specification based on number of states and endogenous variables in the component.
557
584
 
558
- self._component_info = {
559
- self.name: {
560
- "k_states": self.k_states,
561
- "k_endog": self.k_endog,
562
- "k_posdef": self.k_posdef,
563
- "observed_state_names": self.observed_state_names,
564
- "combine_hidden_states": combine_hidden_states,
565
- "obs_state_idx": obs_state_idxs,
566
- "share_states": self.share_states,
567
- }
568
- }
585
+ It is encouraged to override this method.
586
+ """
587
+ state_names = self.base_state_names or [i for i in range(self.k_states or 0)]
588
+ observed_state_names = self.base_observed_state_names or [
589
+ i for i in range(self._k_endog or 0)
590
+ ]
591
+
592
+ hidden_states = [
593
+ State(name=name, observed=False, shared=self.share_states) for name in state_names
594
+ ]
595
+ observed_states = [
596
+ State(name=name, observed=True, shared=self.share_states)
597
+ for name in observed_state_names
598
+ ]
599
+ return *hidden_states, *observed_states
600
+
601
+ def _set_states(self) -> None:
602
+ states = self.set_states()
603
+ _validate_property(states, "states", State)
604
+ if isinstance(states, State):
605
+ states = (states,)
606
+ self._state_info = StateInfo(states=states)
607
+
608
+ def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
609
+ """
610
+ Set component parameter specifications. Since different component types will require different specifications,
611
+ you must override this method.
612
+ """
613
+ return
614
+
615
+ def _set_parameters(self) -> None:
616
+ params = self.set_parameters()
617
+ _validate_property(params, "parameters", Parameter)
618
+ if isinstance(params, Parameter):
619
+ params = (params,)
620
+ self._param_info = ParameterInfo(parameters=params)
621
+
622
+ def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
623
+ """
624
+ Set default shock specifications based on the number of sources of innovations in the component.
625
+
626
+ It is encouraged to override this method.
627
+ """
628
+ return tuple(Shock(name=f"shock_{name}") for name in range(self.k_posdef or 0))
629
+
630
+ def _set_shocks(self) -> None:
631
+ shocks = self.set_shocks()
632
+ _validate_property(shocks, "shocks", Shock)
633
+ if isinstance(shocks, Shock):
634
+ shocks = (shocks,)
635
+ self._shock_info = ShockInfo(shocks=shocks)
636
+
637
+ def set_data_info(self) -> Data | tuple[Data, ...] | None:
638
+ """
639
+ Set default data specifications. Since different component types will require different specifications, you must override this method.
640
+ """
641
+ return
642
+
643
+ def _set_data_info(self) -> None:
644
+ data_info = self.set_data_info()
645
+ _validate_property(data_info, "data_info", Data)
646
+ if isinstance(data_info, Data):
647
+ data_info = (data_info,)
648
+ self._data_info = DataInfo(data=data_info)
649
+
650
+ def set_coords(self) -> Coord | tuple[Coord, ...] | None:
651
+ """
652
+ Set default coordinate specifications. Since different component types will require different specifications, you must override this method.
653
+ """
654
+ return
655
+
656
+ def _set_coords(self) -> None:
657
+ coords = self.set_coords()
658
+ _validate_property(coords, "coords", Coord)
659
+ if isinstance(coords, Coord):
660
+ coords = (coords,)
661
+ self._coords_info = CoordInfo(coords=coords)
662
+
663
+ @property
664
+ def state_names(self):
665
+ return self._state_info.unobserved_state_names
666
+
667
+ @property
668
+ def observed_state_names(self):
669
+ return self._state_info.observed_state_names
670
+
671
+ @property
672
+ def param_names(self):
673
+ return self._param_info.names
674
+
675
+ @property
676
+ def param_info(self):
677
+ return self._param_info
678
+
679
+ @property
680
+ def shock_names(self):
681
+ return self._shock_info.names
682
+
683
+ @property
684
+ def data_names(self):
685
+ return [data.name for data in self._data_info if not data.is_exogenous]
686
+
687
+ @property
688
+ def exog_names(self):
689
+ return self._data_info.exogenous_names
690
+
691
+ @property
692
+ def coords(self):
693
+ return self._coords_info.to_dict()
694
+
695
+ @property
696
+ def param_dims(self):
697
+ return {param.name: param.dims for param in self._param_info if param.dims is not None}
698
+
699
+ @property
700
+ def needs_exog_data(self):
701
+ return self._data_info.needs_exogenous_data
702
+
703
+ @property
704
+ def k_states(self):
705
+ return self._k_states
706
+
707
+ @property
708
+ def k_endog(self):
709
+ return self._k_endog
710
+
711
+ @property
712
+ def k_posdef(self):
713
+ return self._k_posdef
714
+
715
+ @property
716
+ def _name_to_variable(self):
717
+ return self._tensor_variable_info.to_dict()
718
+
719
+ @property
720
+ def _name_to_data(self):
721
+ return self._tensor_data_info.to_dict()
569
722
 
570
723
  def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
571
724
  r"""
@@ -595,20 +748,21 @@ class Component:
595
748
  An error is raised if the provided name has already been registered, or if the name is not present in the
596
749
  ``param_names`` property.
597
750
  """
598
- if name not in self.param_names:
751
+ if name not in self._param_info:
599
752
  raise ValueError(
600
753
  f"{name} is not a model parameter. All placeholder variables should correspond to model "
601
754
  f"parameters."
602
755
  )
603
756
 
604
- if name in self._name_to_variable.keys():
757
+ if name in self._tensor_variable_info:
605
758
  raise ValueError(
606
759
  f"{name} is already a registered placeholder variable with shape "
607
- f"{self._name_to_variable[name].type.shape}"
760
+ f"{self._tensor_variable_info[name].symbolic_variable.type.shape}"
608
761
  )
609
762
 
610
763
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
611
- self._name_to_variable[name] = placeholder
764
+ tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
765
+ self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
612
766
  return placeholder
613
767
 
614
768
  def make_and_register_data(self, name, shape, dtype=floatX) -> Variable:
@@ -632,37 +786,34 @@ class Component:
632
786
  An error is raised if the provided name has already been registered, or if the name is not present in the
633
787
  ``data_names`` property.
634
788
  """
635
- if name not in self.data_names:
789
+ if name not in self._data_info:
636
790
  raise ValueError(
637
791
  f"{name} is not a model parameter. All placeholder variables should correspond to model "
638
792
  f"parameters."
639
793
  )
640
794
 
641
- if name in self._name_to_data.keys():
795
+ if name in self._tensor_data_info:
642
796
  raise ValueError(
643
797
  f"{name} is already a registered placeholder variable with shape "
644
- f"{self._name_to_data[name].type.shape}"
798
+ f"{self._tensor_data_info[name].symbolic_data.type.shape}"
645
799
  )
646
800
 
647
801
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
648
- self._name_to_data[name] = placeholder
802
+ tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
803
+ tensor_data_info = SymbolicDataInfo(symbolic_data=(tensor_data,))
804
+ self._tensor_data_info = self._tensor_data_info.merge(tensor_data_info)
649
805
  return placeholder
650
806
 
651
807
  def make_symbolic_graph(self) -> None:
652
808
  raise NotImplementedError
653
809
 
654
- def populate_component_properties(self):
655
- raise NotImplementedError
656
-
657
810
  def _get_combined_shapes(self, other):
658
811
  k_states = self.k_states + other.k_states
659
812
  k_posdef = self.k_posdef + other.k_posdef
660
813
 
661
814
  # To count endog states, we have to count unique names between the two components.
662
- combined_states = self._combine_property(
663
- other, "observed_state_names", allow_duplicates=False
664
- )
665
- k_endog = len(combined_states)
815
+ combined_states = self._state_info.merge(other._state_info, overwrite_duplicates=True)
816
+ k_endog = len(combined_states.observed_state_names)
666
817
 
667
818
  return k_states, k_posdef, k_endog
668
819
 
@@ -698,7 +849,11 @@ class Component:
698
849
  state_intercept.name = c.name
699
850
 
700
851
  obs_intercept = add_tensors_by_dim_labels(
701
- d, o_d, labels=self_observed_states, other_labels=other_observed_states, labeled_axis=-1
852
+ d,
853
+ o_d,
854
+ labels=list(self_observed_states),
855
+ other_labels=list(other_observed_states),
856
+ labeled_axis=-1,
702
857
  )
703
858
  obs_intercept.name = d.name
704
859
 
@@ -714,8 +869,8 @@ class Component:
714
869
 
715
870
  design = join_tensors_by_dim_labels(
716
871
  *conform_time_varying_and_time_invariant_matrices(Z, o_Z),
717
- labels=self_observed_states,
718
- other_labels=other_observed_states,
872
+ labels=list(self_observed_states),
873
+ other_labels=list(other_observed_states),
719
874
  labeled_axis=-2,
720
875
  join_axis=-1,
721
876
  )
@@ -734,8 +889,8 @@ class Component:
734
889
  obs_cov = add_tensors_by_dim_labels(
735
890
  H,
736
891
  o_H,
737
- labels=self_observed_states,
738
- other_labels=other_observed_states,
892
+ labels=list(self_observed_states),
893
+ other_labels=list(other_observed_states),
739
894
  labeled_axis=(-1, -2),
740
895
  )
741
896
  obs_cov.name = H.name
@@ -760,31 +915,6 @@ class Component:
760
915
 
761
916
  return new_ssm
762
917
 
763
- def _combine_property(self, other, name, allow_duplicates=True):
764
- self_prop = getattr(self, name)
765
- other_prop = getattr(other, name)
766
-
767
- if not isinstance(self_prop, type(other_prop)):
768
- raise TypeError(
769
- f"Property {name} of {self} and {other} are not the same and cannot be combined. Found "
770
- f"{type(self_prop)} for {self} and {type(other_prop)} for {other}'"
771
- )
772
-
773
- if not isinstance(self_prop, list | dict):
774
- raise TypeError(
775
- f"All component properties are expected to be lists or dicts, but found {type(self_prop)}"
776
- f"for property {name} of {self} and {type(other_prop)} for {other}'"
777
- )
778
-
779
- if isinstance(self_prop, list) and allow_duplicates:
780
- return self_prop + other_prop
781
- elif isinstance(self_prop, list) and not allow_duplicates:
782
- return self_prop + [x for x in other_prop if x not in self_prop]
783
- elif isinstance(self_prop, dict):
784
- new_prop = self_prop.copy()
785
- new_prop.update(other_prop)
786
- return new_prop
787
-
788
918
  def _combine_component_info(self, other):
789
919
  combined_info = {}
790
920
  for key, value in self._component_info.items():
@@ -807,22 +937,14 @@ class Component:
807
937
  return name
808
938
 
809
939
  def __add__(self, other):
810
- state_names = self._combine_property(other, "state_names")
811
- data_names = self._combine_property(other, "data_names")
812
- observed_state_names = self._combine_property(
813
- other, "observed_state_names", allow_duplicates=False
814
- )
815
-
816
- param_names = self._combine_property(other, "param_names")
817
- shock_names = self._combine_property(other, "shock_names")
818
- param_info = self._combine_property(other, "param_info")
819
- data_info = self._combine_property(other, "data_info")
820
- param_dims = self._combine_property(other, "param_dims")
821
- coords = self._combine_property(other, "coords")
822
- exog_names = self._combine_property(other, "exog_names")
823
-
824
- _name_to_variable = self._combine_property(other, "_name_to_variable")
825
- _name_to_data = self._combine_property(other, "_name_to_data")
940
+ param_info = self._param_info.merge(other._param_info)
941
+ data_info = self._data_info.merge(other._data_info)
942
+ shock_info = self._shock_info.merge(other._shock_info)
943
+ state_info = self._state_info.merge(other._state_info, overwrite_duplicates=True)
944
+ coords_info = self._coords_info.merge(other._coords_info)
945
+ observed_state_names = state_info.observed_state_names
946
+ tensor_variable_info = self._tensor_variable_info.merge(other._tensor_variable_info)
947
+ tensor_data_info = self._tensor_data_info.merge(other._tensor_data_info)
826
948
 
827
949
  measurement_error = any([self.measurement_error, other.measurement_error])
828
950
 
@@ -835,7 +957,7 @@ class Component:
835
957
  k_endog=k_endog,
836
958
  k_states=k_states,
837
959
  k_posdef=k_posdef,
838
- observed_state_names=observed_state_names,
960
+ base_observed_state_names=list(observed_state_names),
839
961
  measurement_error=measurement_error,
840
962
  representation=ssm,
841
963
  component_from_sum=True,
@@ -844,19 +966,13 @@ class Component:
844
966
  new_comp.name = new_comp._make_combined_name()
845
967
 
846
968
  names_and_props = [
847
- ("state_names", state_names),
848
- ("observed_state_names", observed_state_names),
849
- ("data_names", data_names),
850
- ("param_names", param_names),
851
- ("shock_names", shock_names),
852
- ("param_dims", param_dims),
853
- ("coords", coords),
854
- ("param_dims", param_dims),
855
- ("param_info", param_info),
856
- ("data_info", data_info),
857
- ("exog_names", exog_names),
858
- ("_name_to_variable", _name_to_variable),
859
- ("_name_to_data", _name_to_data),
969
+ ("_coords_info", coords_info),
970
+ ("_param_info", param_info),
971
+ ("_data_info", data_info),
972
+ ("_shock_info", shock_info),
973
+ ("_state_info", state_info),
974
+ ("_tensor_variable_info", tensor_variable_info),
975
+ ("_tensor_data_info", tensor_data_info),
860
976
  ]
861
977
 
862
978
  for prop, value in names_and_props:
@@ -899,20 +1015,15 @@ class Component:
899
1015
  return StructuralTimeSeries(
900
1016
  self.ssm,
901
1017
  name=name,
902
- state_names=self.state_names,
903
- observed_state_names=self.observed_state_names,
904
- data_names=self.data_names,
905
- shock_names=self.shock_names,
906
- param_names=self.param_names,
907
- param_dims=self.param_dims,
908
- coords=self.coords,
909
- param_info=self.param_info,
910
- data_info=self.data_info,
1018
+ coords_info=self._coords_info,
1019
+ param_info=self._param_info,
1020
+ data_info=self._data_info,
1021
+ shock_info=self._shock_info,
1022
+ state_info=self._state_info,
1023
+ tensor_variable_info=self._tensor_variable_info,
1024
+ tensor_data_info=self._tensor_data_info,
911
1025
  component_info=self._component_info,
912
1026
  measurement_error=self.measurement_error,
913
- exog_names=self.exog_names,
914
- name_to_variable=self._name_to_variable,
915
- name_to_data=self._name_to_data,
916
1027
  filter_type=filter_type,
917
1028
  verbose=verbose,
918
1029
  mode=mode,