pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/__init__.py +8 -1
  7. pymc_extras/inference/dadvi/__init__.py +0 -0
  8. pymc_extras/inference/dadvi/dadvi.py +351 -0
  9. pymc_extras/inference/fit.py +5 -0
  10. pymc_extras/inference/laplace_approx/find_map.py +32 -47
  11. pymc_extras/inference/laplace_approx/idata.py +27 -6
  12. pymc_extras/inference/laplace_approx/laplace.py +24 -6
  13. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  14. pymc_extras/inference/pathfinder/idata.py +517 -0
  15. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  16. pymc_extras/model/marginal/graph_analysis.py +2 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  21. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  22. pymc_extras/statespace/filters/utilities.py +2 -5
  23. pymc_extras/statespace/models/DFM.py +834 -0
  24. pymc_extras/statespace/models/ETS.py +190 -198
  25. pymc_extras/statespace/models/SARIMAX.py +9 -21
  26. pymc_extras/statespace/models/VARMAX.py +22 -74
  27. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  28. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  29. pymc_extras/statespace/models/utilities.py +7 -0
  30. pymc_extras/statespace/utils/constants.py +3 -1
  31. pymc_extras/utils/model_equivalence.py +2 -2
  32. pymc_extras/utils/prior.py +10 -14
  33. pymc_extras/utils/spline.py +4 -10
  34. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  35. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
  36. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  37. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode
9
9
  from pytensor.tensor.slinalg import solve_discrete_lyapunov
10
10
 
11
11
  from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
12
- from pymc_extras.statespace.models.utilities import make_default_coords
12
+ from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
13
13
  from pymc_extras.statespace.utils.constants import (
14
14
  ALL_STATE_AUX_DIM,
15
15
  ALL_STATE_DIM,
@@ -20,213 +20,209 @@ from pymc_extras.statespace.utils.constants import (
20
20
 
21
21
 
22
22
  class BayesianETS(PyMCStateSpace):
23
- def __init__(
24
- self,
25
- order: tuple[str, str, str] | None = None,
26
- endog_names: str | list[str] | None = None,
27
- k_endog: int = 1,
28
- trend: bool = True,
29
- damped_trend: bool = False,
30
- seasonal: bool = False,
31
- seasonal_periods: int | None = None,
32
- measurement_error: bool = False,
33
- use_transformed_parameterization: bool = False,
34
- dense_innovation_covariance: bool = False,
35
- stationary_initialization: bool = False,
36
- initialization_dampening: float = 0.8,
37
- filter_type: str = "standard",
38
- verbose: bool = True,
39
- mode: str | Mode | None = None,
40
- ):
41
- r"""
42
- Exponential Smoothing State Space Model
23
+ r"""
24
+ Exponential Smoothing State Space Model
25
+
26
+ This class can represent a subset of exponential smoothing state space models, specifically those with additive
27
  + errors. Following [1]_, the general form of the model is:
28
+
29
+ .. math::
30
+
31
+ \begin{align}
32
+ y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
33
+ \epsilon_t &\sim N(0, \sigma)
34
+ \end{align}
35
+
36
+ where :math:`l_t` is the level component, :math:`b_t` is the trend component, and :math:`s_t` is the seasonal
37
+ component. These components can be included or excluded, leading to different model specifications. The following
38
+ models are possible:
39
+
40
+ * `ETS(A,N,N)`: Simple exponential smoothing
41
+
42
+ .. math::
43
+
44
+ \begin{align}
45
+ y_t &= l_{t-1} + \epsilon_t \\
46
+ l_t &= l_{t-1} + \alpha \epsilon_t
47
+ \end{align}
48
+
49
+ Where :math:`\alpha \in [0, 1]` is a mixing parameter between past observations and current innovations.
50
+ These equations arise by starting from the "component form":
51
+
52
+ .. math::
53
+
54
+ \begin{align}
55
+ \hat{y}_{t+1 | t} &= l_t \\
56
+ l_t &= \alpha y_t + (1 - \alpha) l_{t-1} \\
57
+ &= l_{t-1} + \alpha (y_t - l_{t-1})
58
+ &= l_{t-1} + \alpha \epsilon_t
59
+ \end{align}
60
+
61
  + Where :math:`\epsilon_t` are the forecast errors, assumed to be IID mean zero and normally distributed. The role of
62
+ :math:`\alpha` is clearest in the second line. The level of the time series at each time is a mixture of
63
+ :math:`\alpha` percent of the incoming data, and :math:`1 - \alpha` percent of the previous level. Recursive
64
+ substitution reveals that the level is a weighted composite of all previous observations; thus the name
65
+ "Exponential Smoothing".
66
+
67
  + Additional supported specifications include:
68
+
69
+ * `ETS(A,A,N)`: Holt's linear trend method
70
+
71
+ .. math::
72
+
73
+ \begin{align}
74
+ y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
75
+ l_t &= l_{t-1} + b_{t-1} + \alpha \epsilon_t \\
76
+ b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t
77
+ \end{align}
78
+
79
+ [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
80
+
81
+ * `ETS(A,N,A)`: Additive seasonal method
82
+
83
+ .. math::
43
84
 
44
- This class can represent a subset of exponential smoothing state space models, specifically those with additive
45
- errors. Following .. [1], The general form of the model is:
85
+ \begin{align}
86
+ y_t &= l_{t-1} + s_{t-m} + \epsilon_t \\
87
+ l_t &= l_{t-1} + \alpha \epsilon_t \\
88
+ s_t &= s_{t-m} + (1 - \alpha)\gamma^\star \epsilon_t
89
+ \end{align}
90
+
91
+ [1]_ also consider an alternative parameterization with :math:`\gamma = (1 - \alpha) \gamma^\star`.
92
+
93
+ * `ETS(A,A,A)`: Additive Holt-Winters method
46
94
 
47
95
  .. math::
48
96
 
49
97
  \begin{align}
50
98
  y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
51
- \epsilon_t &\sim N(0, \sigma)
99
+ l_t &= l_{t-1} + \alpha \epsilon_t \\
100
+ b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t \\
101
+ s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
52
102
  \end{align}
53
103
 
54
- where :math:`l_t` is the level component, :math:`b_t` is the trend component, and :math:`s_t` is the seasonal
55
- component. These components can be included or excluded, leading to different model specifications. The following
56
- models are possible:
104
  + [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star` and
105
+ :math:`\gamma = (1 - \alpha) \gamma^\star`.
106
+
107
+ * `ETS(A, Ad, N)`: Dampened trend method
57
108
 
58
- * `ETS(A,N,N)`: Simple exponential smoothing
109
+ .. math::
110
+
111
+ \begin{align}
112
+ y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
113
+ l_t &= l_{t-1} + \alpha \epsilon_t \\
114
+ b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t
115
+ \end{align}
116
+
117
+ [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
59
118
 
60
- .. math::
119
+ * `ETS(A, Ad, A)`: Dampened trend with seasonal method
61
120
 
62
- \begin{align}
63
- y_t &= l_{t-1} + \epsilon_t \\
64
- l_t &= l_{t-1} + \alpha \epsilon_t
65
- \end{align}
121
+ .. math::
66
122
 
67
- Where :math:`\alpha \in [0, 1]` is a mixing parameter between past observations and current innovations.
68
- These equations arise by starting from the "component form":
123
+ \begin{align}
124
+ y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
125
+ l_t &= l_{t-1} + \alpha \epsilon_t \\
126
+ b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t \\
127
+ s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
128
+ \end{align}
69
129
 
70
- .. math::
130
  + [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star` and
131
+ :math:`\gamma = (1 - \alpha) \gamma^\star`.
132
+
133
+
134
+ Parameters
135
+ ----------
136
+ order: tuple of string, Optional
137
+ The exponential smoothing "order". This is a tuple of three strings, each of which should be one of 'A', 'Ad',
138
+ or 'N'.
139
+ If provided, the model will be initialized from the given order, and the `trend`, `damped_trend`, and `seasonal`
140
+ arguments will be ignored.
141
+ endog_names: str or list of str
142
+ Names associated with observed states. If a list, the length should be equal to the number of time series
143
+ to be estimated.
144
+ trend: bool
145
+ Whether to include a trend component. Setting ``trend=True`` is equivalent to ``order[1] == 'A'``.
146
+ damped_trend: bool
147
+ Whether to include a damping parameter on the trend component. Ignored if `trend` is `False`. Setting
148
+ ``trend=True`` and ``damped_trend=True`` is equivalent to order[1] == 'Ad'.
149
+ seasonal: bool
150
+ Whether to include a seasonal component. Setting ``seasonal=True`` is equivalent to ``order[2] = 'A'``.
151
+ seasonal_periods: int
152
+ The number of periods in a complete seasonal cycle. Ignored if `seasonal` is `False`
153
+ (or if ``order[2] == "N"``)
154
+ measurement_error: bool
155
+ Whether to include a measurement error term in the model. Default is `False`.
156
+ use_transformed_parameterization: bool, default False
157
+ If true, use the :math:`\alpha, \beta, \gamma` parameterization, otherwise use the :math:`\alpha, \beta^\star,
158
+ \gamma^\star` parameterization. This will change the admissible region for the priors.
159
+
160
+ - Under the **non-transformed** parameterization, all of :math:`\alpha, \beta^\star, \gamma^\star` should be
161
+ between 0 and 1.
162
+ - Under the **transformed** parameterization, :math:`\alpha \in (0, 1)`, :math:`\beta \in (0, \alpha)`, and
163
+ :math:`\gamma \in (0, 1 - \alpha)`
164
+
165
+ The :meth:`param_info` method will change to reflect the suggested intervals based on the value of this
166
+ argument.
167
+ dense_innovation_covariance: bool, default False
168
  + Whether to estimate a dense covariance for statespace innovations. In an ETS model, each observed variable
169
+ has a single source of stochastic variation. If True, these innovations are allowed to be correlated.
170
+ Ignored if ``k_endog == 1``
171
+ stationary_initialization: bool, default False
172
+ If True, the Kalman Filter's initial covariance matrix will be set to an approximate steady-state value.
173
+ The approximation is formed by adding a small dampening factor to each state. Specifically, the level state
174
+ for a ('A', 'N', 'N') model is written:
71
175
 
72
- \begin{align}
73
- \hat{y}_{t+1 | t} &= l_t \\
74
- l_t &= \alpha y_t + (1 - \alpha) l_{t-1} \\
75
- &= l_{t-1} + \alpha (y_t - l_{t-1})
76
- &= l_{t-1} + \alpha \epsilon_t
77
- \end{align}
176
+ .. math::
177
+ \ell_t = \ell_{t-1} + \alpha * e_t
78
178
 
79
- Where $\epsilon_t$ are the forecast errors, assumed to be IID mean zero and normally distributed. The role of
80
- :math:`\alpha` is clearest in the second line. The level of the time series at each time is a mixture of
81
- :math:`\alpha` percent of the incoming data, and :math:`1 - \alpha` percent of the previous level. Recursive
82
- substitution reveals that the level is a weighted composite of all previous observations; thus the name
83
- "Exponential Smoothing".
84
-
85
- Additional supposed specifications include:
179
+ That this system is not stationary can be understood in ARIMA terms: the level is a random walk; that is,
180
  + :math:`\rho = 1`. This can be remedied by pretending that we instead have a dampened system:
86
181
 
87
- * `ETS(A,A,N)`: Holt's linear trend method
88
-
89
- .. math::
90
-
91
- \begin{align}
92
- y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
93
- l_t &= l_{t-1} + b_{t-1} + \alpha \epsilon_t \\
94
- b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t
95
- \end{align}
96
-
97
- [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
98
-
99
- * `ETS(A,N,A)`: Additive seasonal method
100
-
101
- .. math::
102
-
103
- \begin{align}
104
- y_t &= l_{t-1} + s_{t-m} + \epsilon_t \\
105
- l_t &= l_{t-1} + \alpha \epsilon_t \\
106
- s_t &= s_{t-m} + (1 - \alpha)\gamma^\star \epsilon_t
107
- \end{align}
108
-
109
- [1]_ also consider an alternative parameterization with :math:`\gamma = (1 - \alpha) \gamma^\star`.
110
-
111
- * `ETS(A,A,A)`: Additive Holt-Winters method
112
-
113
- .. math::
114
-
115
- \begin{align}
116
- y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
117
- l_t &= l_{t-1} + \alpha \epsilon_t \\
118
- b_t &= b_{t-1} + \alpha \beta^\star \epsilon_t \\
119
- s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
120
- \end{align}
121
-
122
- [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^star` and
123
- :math:`\gamma = (1 - \alpha) \gamma^\star`.
124
-
125
- * `ETS(A, Ad, N)`: Dampened trend method
126
-
127
- .. math::
128
-
129
- \begin{align}
130
- y_t &= l_{t-1} + b_{t-1} + \epsilon_t \\
131
- l_t &= l_{t-1} + \alpha \epsilon_t \\
132
- b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t
133
- \end{align}
134
-
135
- [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^\star`.
136
-
137
- * `ETS(A, Ad, A)`: Dampened trend with seasonal method
138
-
139
- .. math::
140
-
141
- \begin{align}
142
- y_t &= l_{t-1} + b_{t-1} + s_{t-m} + \epsilon_t \\
143
- l_t &= l_{t-1} + \alpha \epsilon_t \\
144
- b_t &= \phi b_{t-1} + \alpha \beta^\star \epsilon_t \\
145
- s_t &= s_{t-m} + (1 - \alpha) \gamma^\star \epsilon_t
146
- \end{align}
147
-
148
- [1]_ also consider an alternative parameterization with :math:`\beta = \alpha \beta^star` and
149
- :math:`\gamma = (1 - \alpha) \gamma^\star`.
150
-
151
-
152
- Parameters
153
- ----------
154
- order: tuple of string, Optional
155
- The exponential smoothing "order". This is a tuple of three strings, each of which should be one of 'A', 'Ad',
156
- or 'N'.
157
- If provided, the model will be initialized from the given order, and the `trend`, `damped_trend`, and `seasonal`
158
- arguments will be ignored.
159
- endog_names: str or list of str, Optional
160
- Names associated with observed states. If a list, the length should be equal to the number of time series
161
- to be estimated.
162
- k_endog: int, Optional
163
- Number of time series to estimate. If endog_names are provided, this is ignored and len(endog_names) is
164
- used instead.
165
- trend: bool
166
- Whether to include a trend component. Setting ``trend=True`` is equivalent to ``order[1] == 'A'``.
167
- damped_trend: bool
168
- Whether to include a damping parameter on the trend component. Ignored if `trend` is `False`. Setting
169
- ``trend=True`` and ``damped_trend=True`` is equivalent to order[1] == 'Ad'.
170
- seasonal: bool
171
- Whether to include a seasonal component. Setting ``seasonal=True`` is equivalent to ``order[2] = 'A'``.
172
- seasonal_periods: int
173
- The number of periods in a complete seasonal cycle. Ignored if `seasonal` is `False`
174
- (or if ``order[2] == "N"``)
175
- measurement_error: bool
176
- Whether to include a measurement error term in the model. Default is `False`.
177
- use_transformed_parameterization: bool, default False
178
- If true, use the :math:`\alpha, \beta, \gamma` parameterization, otherwise use the :math:`\alpha, \beta^\star,
179
- \gamma^\star` parameterization. This will change the admissible region for the priors.
180
-
181
- - Under the **non-transformed** parameterization, all of :math:`\alpha, \beta^\star, \gamma^\star` should be
182
- between 0 and 1.
183
- - Under the **transformed** parameterization, :math:`\alpha \in (0, 1)`, :math:`\beta \in (0, \alpha)`, and
184
- :math:`\gamma \in (0, 1 - \alpha)`
185
-
186
- The :meth:`param_info` method will change to reflect the suggested intervals based on the value of this
187
- argument.
188
- dense_innovation_covariance: bool, default False
189
- Whether to estimate a dense covariance for statespace innovations. In an ETS models, each observed variable
190
- has a single source of stochastic variation. If True, these innovations are allowed to be correlated.
191
- Ignored if ``k_endog == 1``
192
- stationary_initialization: bool, default False
193
- If True, the Kalman Filter's initial covariance matrix will be set to an approximate steady-state value.
194
- The approximation is formed by adding a small dampening factor to each state. Specifically, the level state
195
- for a ('A', 'N', 'N') model is written:
196
-
197
- .. math::
198
- \ell_t = \ell_{t-1} + \alpha * e_t
199
-
200
- That this system is not stationary can be understood in ARIMA terms: the level is a random walk; that is,
201
- :math:`rho = 1`. This can be remedied by pretending that we instead have a dampened system:
202
-
203
- .. math::
204
- \ell_t = \rho \ell_{t-1} + \alpha * e_t
205
-
206
- With :math:`\rho \approx 1`, the system is stationary, and we can solve for the steady-state covariance
207
- matrix. This is then used as the initial covariance matrix for the Kalman Filter. This is a heuristic
208
- method that helps avoid setting a prior on the initial covariance matrix.
209
- initialization_dampening: float, default 0.8
210
- Dampening factor to add to non-stationary model components. This is only used for initialization, it does
211
- *not* add dampening to the model. Ignored if `stationary_initialization` is `False`.
212
- filter_type: str, default "standard"
213
- The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
214
- and "cholesky". See the docs for kalman filters for more details.
215
- verbose: bool, default True
216
- If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
217
- mode: str or Mode, optional
218
- Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
219
- ``forecast``. The mode does **not** effect calls to ``pm.sample``.
220
-
221
- Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
222
- to all sampling methods.
223
-
224
-
225
- References
226
- ----------
227
- .. [1] Hyndman, Rob J., and George Athanasopoulos. Forecasting: principles and practice. OTexts, 2018.
228
- """
182
+ .. math::
183
+ \ell_t = \rho \ell_{t-1} + \alpha * e_t
184
+
185
+ With :math:`\rho \approx 1`, the system is stationary, and we can solve for the steady-state covariance
186
+ matrix. This is then used as the initial covariance matrix for the Kalman Filter. This is a heuristic
187
+ method that helps avoid setting a prior on the initial covariance matrix.
188
+ initialization_dampening: float, default 0.8
189
+ Dampening factor to add to non-stationary model components. This is only used for initialization, it does
190
+ *not* add dampening to the model. Ignored if `stationary_initialization` is `False`.
191
+ filter_type: str, default "standard"
192
+ The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
193
+ and "cholesky". See the docs for kalman filters for more details.
194
+ verbose: bool, default True
195
+ If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
196
+ mode: str or Mode, optional
197
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
198
  + ``forecast``. The mode does **not** affect calls to ``pm.sample``.
199
+
200
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
201
+ to all sampling methods.
202
+
203
+
204
+ References
205
+ ----------
206
+ .. [1] Hyndman, Rob J., and George Athanasopoulos. Forecasting: principles and practice. OTexts, 2018.
207
+ """
229
208
 
209
+ def __init__(
210
+ self,
211
+ order: tuple[str, str, str] | None = None,
212
+ endog_names: str | list[str] | None = None,
213
+ trend: bool = True,
214
+ damped_trend: bool = False,
215
+ seasonal: bool = False,
216
+ seasonal_periods: int | None = None,
217
+ measurement_error: bool = False,
218
+ use_transformed_parameterization: bool = False,
219
+ dense_innovation_covariance: bool = False,
220
+ stationary_initialization: bool = False,
221
+ initialization_dampening: float = 0.8,
222
+ filter_type: str = "standard",
223
+ verbose: bool = True,
224
+ mode: str | Mode | None = None,
225
+ ):
230
226
  if order is not None:
231
227
  if len(order) != 3 or any(not isinstance(o, str) for o in order):
232
228
  raise ValueError("Order must be a tuple of three strings.")
@@ -265,13 +261,9 @@ class BayesianETS(PyMCStateSpace):
265
261
  if self.seasonal and self.seasonal_periods is None:
266
262
  raise ValueError("If seasonal is True, seasonal_periods must be provided.")
267
263
 
268
- if endog_names is not None:
269
- endog_names = list(endog_names)
270
- k_endog = len(endog_names)
271
- else:
272
- endog_names = [f"data_{i}" for i in range(k_endog)] if k_endog > 1 else ["data"]
273
-
274
- self.endog_names = endog_names
264
+ validate_names(endog_names, var_name="endog_names", optional=False)
265
+ k_endog = len(endog_names)
266
+ self.endog_names = list(endog_names)
275
267
 
276
268
  if dense_innovation_covariance and k_endog == 1:
277
269
  dense_innovation_covariance = False
@@ -12,12 +12,13 @@ from pymc_extras.statespace.models.utilities import (
12
12
  make_default_coords,
13
13
  make_harvey_state_names,
14
14
  make_SARIMA_transition_matrix,
15
+ validate_names,
15
16
  )
16
17
  from pymc_extras.statespace.utils.constants import (
17
18
  ALL_STATE_AUX_DIM,
18
19
  ALL_STATE_DIM,
19
20
  AR_PARAM_DIM,
20
- EXOGENOUS_DIM,
21
+ EXOG_STATE_DIM,
21
22
  MA_PARAM_DIM,
22
23
  OBS_STATE_DIM,
23
24
  SARIMAX_STATE_STRUCTURES,
@@ -132,7 +133,6 @@ class BayesianSARIMAX(PyMCStateSpace):
132
133
  order: tuple[int, int, int],
133
134
  seasonal_order: tuple[int, int, int, int] | None = None,
134
135
  exog_state_names: list[str] | None = None,
135
- k_exog: int | None = None,
136
136
  stationary_initialization: bool = True,
137
137
  filter_type: str = "standard",
138
138
  state_structure: str = "fast",
@@ -166,10 +166,6 @@ class BayesianSARIMAX(PyMCStateSpace):
166
166
  exog_state_names : list[str], optional
167
167
  Names of the exogenous state variables.
168
168
 
169
- k_exog : int, optional
170
- Number of exogenous variables. If provided, must match the length of
171
- `exog_state_names`.
172
-
173
169
  stationary_initialization : bool, default True
174
170
  If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
175
171
  state values will be used.
@@ -212,18 +208,10 @@ class BayesianSARIMAX(PyMCStateSpace):
212
208
  if seasonal_order is None:
213
209
  seasonal_order = (0, 0, 0, 0)
214
210
 
215
- if exog_state_names is None and k_exog is not None:
216
- exog_state_names = [f"exogenous_{i}" for i in range(k_exog)]
217
- elif exog_state_names is not None and k_exog is None:
218
- k_exog = len(exog_state_names)
219
- elif exog_state_names is not None and k_exog is not None:
220
- if len(exog_state_names) != k_exog:
221
- raise ValueError(
222
- f"Based on provided inputs, expected exog_state_names to have {k_exog} elements, but "
223
- f"found {len(exog_state_names)}"
224
- )
225
- else:
226
- k_exog = 0
211
+ validate_names(
212
+ exog_state_names, var_name="exog_state_names", optional=True
213
+ ) # Not sure if this adds anything
214
+ k_exog = len(exog_state_names) if exog_state_names is not None else 0
227
215
 
228
216
  self.exog_state_names = exog_state_names
229
217
  self.k_exog = k_exog
@@ -315,7 +303,7 @@ class BayesianSARIMAX(PyMCStateSpace):
315
303
  def data_info(self) -> dict[str, dict[str, Any]]:
316
304
  info = {
317
305
  "exogenous_data": {
318
- "dims": (TIME_DIM, EXOGENOUS_DIM),
306
+ "dims": (TIME_DIM, EXOG_STATE_DIM),
319
307
  "shape": (None, self.k_exog),
320
308
  }
321
309
  }
@@ -403,7 +391,7 @@ class BayesianSARIMAX(PyMCStateSpace):
403
391
  "ma_params": (MA_PARAM_DIM,),
404
392
  "seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
405
393
  "seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
406
- "beta_exog": (EXOGENOUS_DIM,),
394
+ "beta_exog": (EXOG_STATE_DIM,),
407
395
  }
408
396
  if self.k_endog == 1:
409
397
  coord_map["sigma_state"] = None
@@ -438,7 +426,7 @@ class BayesianSARIMAX(PyMCStateSpace):
438
426
  if self.Q > 0:
439
427
  coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})
440
428
  if self.k_exog > 0:
441
- coords.update({EXOGENOUS_DIM: self.exog_state_names})
429
+ coords.update({EXOG_STATE_DIM: self.exog_state_names})
442
430
  return coords
443
431
 
444
432
  def _stationary_initialization(self):