pymc-extras 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. pymc_extras/distributions/__init__.py +5 -5
  2. pymc_extras/distributions/histogram_utils.py +1 -1
  3. pymc_extras/inference/__init__.py +1 -1
  4. pymc_extras/printing.py +1 -1
  5. pymc_extras/statespace/__init__.py +4 -4
  6. pymc_extras/statespace/core/__init__.py +1 -1
  7. pymc_extras/statespace/core/representation.py +8 -8
  8. pymc_extras/statespace/core/statespace.py +94 -23
  9. pymc_extras/statespace/filters/__init__.py +3 -3
  10. pymc_extras/statespace/filters/kalman_filter.py +16 -11
  11. pymc_extras/statespace/models/SARIMAX.py +138 -74
  12. pymc_extras/statespace/models/VARMAX.py +248 -57
  13. pymc_extras/statespace/models/__init__.py +2 -2
  14. pymc_extras/statespace/models/structural/__init__.py +4 -4
  15. pymc_extras/statespace/models/structural/components/autoregressive.py +49 -24
  16. pymc_extras/statespace/models/structural/components/cycle.py +48 -28
  17. pymc_extras/statespace/models/structural/components/level_trend.py +61 -29
  18. pymc_extras/statespace/models/structural/components/measurement_error.py +22 -5
  19. pymc_extras/statespace/models/structural/components/regression.py +47 -18
  20. pymc_extras/statespace/models/structural/components/seasonality.py +278 -95
  21. pymc_extras/statespace/models/structural/core.py +27 -8
  22. pymc_extras/statespace/utils/constants.py +17 -14
  23. pymc_extras/statespace/utils/data_tools.py +1 -1
  24. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/METADATA +1 -1
  25. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/RECORD +27 -27
  26. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/WHEEL +0 -0
  27. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -14,24 +14,29 @@ class TimeSeasonality(Component):
14
14
  ----------
15
15
  season_length: int
16
16
  The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
17
- daily data with weekly seasonal pattern, etc.
17
+ daily data with weekly seasonal pattern, etc. It must be greater than one.
18
+
19
+ duration: int, default 1
20
+ Number of time steps for each seasonal period.
21
+ This determines how long each seasonal period is held constant before moving to the next.
18
22
 
19
23
  innovations: bool, default True
20
24
  Whether to include stochastic innovations in the strength of the seasonal effect
21
25
 
22
26
  name: str, default None
23
27
  A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
24
- components are included in the same model. Default is ``f"Seasonal[s={season_length}]"``
28
+ components are included in the same model. Default is ``f"Seasonal[s={season_length}, d={duration}]"``
25
29
 
26
30
  state_names: list of str, default None
27
- List of strings for seasonal effect labels. If provided, it must be of length ``season_length``. An example
28
- would be ``state_names = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']`` when data is daily with a weekly
31
+ List of strings for seasonal effect labels. If provided, it must be of length ``season_length`` times ``duration``.
32
+ An example would be ``state_names = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']`` when data is daily with a weekly
29
33
  seasonal pattern (``season_length = 7``).
30
34
 
31
- If None, states will be numbered ``[State_0, ..., State_s]``
35
+ If None and ``duration = 1``, states will be named as ``[State_0, ..., State_s-1]`` (here s is ``season_length``).
36
+ If None and ``duration > 1``, states will be named as ``[State_0_0, ..., State_s-1_d-1]`` (here d is ``duration``).
32
37
 
33
38
  remove_first_state: bool, default True
34
- If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
39
+ If True, the first state will be removed from the model. This is done because there are only ``season_length-1`` degrees of
35
40
  freedom in the seasonal component, and one state is not identified. If False, the first state will be
36
41
  included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
37
42
  ZeroSumNormal).
@@ -39,19 +44,83 @@ class TimeSeasonality(Component):
39
44
  observed_state_names: list[str] | None, default None
40
45
  List of strings for observed state labels. If None, defaults to ["data"].
41
46
 
47
+ share_states: bool, default False
48
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
49
+ states, which are observed by all observed states. If False, each observed state has its own set of
50
+ latent states. This argument has no effect if `k_endog` is 1.
51
+
42
52
  Notes
43
53
  -----
44
- A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
45
- model seasonal effects, the implementation used here is the one described by [1] as the "canonical" time domain
46
- representation. The seasonal component can be expressed:
54
+ A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
55
+ here, we present two models that are straightforward extensions of those described in [1].
56
+
57
+ **First model** (``remove_first_state=True``)
58
+
59
+ In this model, the state vector is defined as:
47
60
 
48
61
  .. math::
49
- \gamma_t = -\sum_{i=1}^{s-1} \gamma_{t-i} + \omega_t, \quad \omega_t \sim N(0, \sigma_\gamma)
62
+ \alpha_t :=(\gamma_t, \ldots, \gamma_{t-d(s-1)+1}), \quad t \ge 0.
63
+
64
+ This vector has length :math:`d(s-1)`, where:
65
+
66
+ - :math:`s` is the ``season_length`` parameter, and
67
+ - :math:`d` is the ``duration`` parameter.
68
+
69
+ The components of the initial vector :math:`\alpha_{0}` are given by
70
+
71
+ .. math::
72
+ \gamma_{-l} := \tilde{\gamma}_{k_l}, \quad \text{where} \quad k_l := \left\lfloor \frac{l}{d} \right\rfloor \bmod s \quad \text{and} \quad l=0,\ldots, d(s-1)-1.
73
+
74
+ Here, the values
75
+
76
+ .. math::
77
+ \tilde{\gamma}_{0}, \ldots, \tilde{\gamma}_{s-2},
78
+
79
+ represent the initial seasonal states. The transition matrix of this model is the :math:`d(s-1) \times d(s-1)` matrix
80
+
81
+ .. math::
82
+ \begin{bmatrix}
83
+ -\mathbf{1}_d & -\mathbf{1}_d & \cdots & -\mathbf{1}_d & -\mathbf{1}_d \\
84
+ \mathbf{1}_d & \mathbf{0}_d & \cdots & \mathbf{0}_d & \mathbf{0}_d \\
85
+ \mathbf{0}_d & \mathbf{1}_d & \cdots & \mathbf{0}_d & \mathbf{0}_d \\
86
+ \vdots & \vdots & \ddots & \vdots & \vdots \\
87
+ \mathbf{0}_d & \mathbf{0}_d & \cdots & \mathbf{1}_d & \mathbf{0}_d
88
+ \end{bmatrix}
89
+
90
+ where :math:`\mathbf{1}_d` and :math:`\mathbf{0}_d` denote the :math:`d \times d` identity and null matrices, respectively.
91
+
92
+ **Second model** (``remove_first_state=False``)
50
93
 
51
- Where :math:`s` is the ``seasonal_length`` parameter and :math:`\omega_t` is the (optional) stochastic innovation.
52
- To give interpretation to the :math:`\gamma` terms, it is helpful to work through the algebra for a simple
53
- example. Let :math:`s=4`, and omit the shock term. Define initial conditions :math:`\gamma_0, \gamma_{-1},
54
- \gamma_{-2}`. The value of the seasonal component for the first 5 timesteps will be:
94
+ In contrast, the state vector in the second model is defined as:
95
+
96
+ .. math::
97
+ \alpha_t=(\gamma_t, \ldots, \gamma_{t-ds+1}), \quad t \ge 0.
98
+
99
+ This vector has length :math:`ds`. The components of the initial state vector :math:`\alpha_{0}` are defined similarly:
100
+
101
+ .. math::
102
+ \gamma_{-l} := \tilde{\gamma}_{k_l}, \quad \text{where} \quad k_l := \left\lfloor \frac{l}{d} \right\rfloor \bmod s \quad \text{and} \quad l=0,\ldots, ds-1.
103
+
104
+ In this case, the initial seasonal states :math:`\tilde{\gamma}_{0}, \ldots, \tilde{\gamma}_{s-1}` are required to satisfy the following condition:
105
+
106
+ .. math::
107
+ \sum_{i=0}^{s-1} \tilde{\gamma}_{i} = 0.
108
+
109
+ The transition matrix of this model is the following :math:`ds \times ds` circulant matrix:
110
+
111
+ .. math::
112
+ \begin{bmatrix}
113
+ 0 & 1 & 0 & \cdots & 0 \\
114
+ 0 & 0 & 1 & \cdots & 0 \\
115
+ \vdots & \vdots & \ddots & \ddots & \vdots \\
116
+ 0 & 0 & \cdots & 0 & 1 \\
117
+ 1 & 0 & \cdots & 0 & 0
118
+ \end{bmatrix}
119
+
120
+ To give interpretation to the :math:`\gamma` terms, it is helpful to work through the algebra for a simple
121
+ example. Let :math:`s=4`, :math:`d=1`, ``remove_first_state=True``, and omit the shock term. Then, we have
122
+ :math:`\gamma_{-i} = \tilde{\gamma}_{-i}`, for :math:`i=0,\ldots, 2` and the value of the seasonal component
123
+ for the first 5 timesteps will be:
55
124
 
56
125
  .. math::
57
126
  \begin{align}
@@ -85,10 +154,38 @@ class TimeSeasonality(Component):
85
154
  And so on. So for interpretation, the ``season_length - 1`` initial states are, when reversed, the coefficients
86
155
  associated with ``state_names[1:]``.
87
156
 
157
+ In the next example, we set :math:`s=3`, :math:`d=2`, ``remove_first_state=True``, and omit the shock term.
158
+ By definition, the initial vector :math:`\alpha_{0}` is
159
+
160
+ .. math::
161
+ \alpha_0=(\tilde{\gamma}_{0}, \tilde{\gamma}_{0}, \tilde{\gamma}_{-1}, \tilde{\gamma}_{-1})
162
+
163
+ and the transition matrix is
164
+
165
+ .. math::
166
+ \begin{bmatrix}
167
+ -1 & 0 & -1 & 0 \\
168
+ 0 & -1 & 0 & -1 \\
169
+ 1 & 0 & 0 & 0 \\
170
+ 0 & 1 & 0 & 0 \\
171
+ \end{bmatrix}
172
+
173
+ It is easy to verify that:
174
+
175
+ .. math::
176
+ \begin{align}
177
+ \gamma_1 &= -\tilde{\gamma}_0 - \tilde{\gamma}_{-1}\\
178
+ \gamma_2 &= -(-\tilde{\gamma}_0 - \tilde{\gamma}_{-1})-\tilde{\gamma}_0\\
179
+ &= \tilde{\gamma}_{-1}\\
180
+ \gamma_3 &= -\tilde{\gamma}_{-1} +(\tilde{\gamma}_0 + \tilde{\gamma}_{-1})\\
181
+ &= \tilde{\gamma}_{0}\\
182
+ \gamma_4 &= -\tilde{\gamma}_0 - \tilde{\gamma}_{-1}.\\
183
+ \end{align}
184
+
88
185
  .. warning::
89
- Although the ``state_names`` argument expects a list of length ``season_length``, only ``state_names[1:]``
90
- will be saved as model dimensions, since the 1st coefficient is not identified (it is defined as
91
- :math:`-\sum_{i=1}^{s} \gamma_{t-i}`).
186
+ Although the ``state_names`` argument expects a list of length ``season_length`` times ``duration``,
187
+ only ``state_names[duration:]`` will be saved as model dimensions, since the first coefficient is not identified
188
+ (it is defined as :math:`-\sum_{i=1}^{s-1} \tilde{\gamma}_{-i}`).
92
189
 
93
190
  Examples
94
191
  --------
@@ -120,7 +217,7 @@ class TimeSeasonality(Component):
120
217
  sigma_level_trend = pm.HalfNormal(
121
218
  "sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
122
219
  )
123
- coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"])
220
+ params_annual = pm.Normal("params_annual", sigma=1e-2, dims=ss_mod.param_dims["params_annual"])
124
221
 
125
222
  ss_mod.build_statespace_graph(data)
126
223
  idata = pm.sample(
@@ -137,82 +234,109 @@ class TimeSeasonality(Component):
137
234
  def __init__(
138
235
  self,
139
236
  season_length: int,
237
+ duration: int = 1,
140
238
  innovations: bool = True,
141
239
  name: str | None = None,
142
240
  state_names: list | None = None,
143
241
  remove_first_state: bool = True,
144
242
  observed_state_names: list[str] | None = None,
243
+ share_states: bool = False,
145
244
  ):
146
245
  if observed_state_names is None:
147
246
  observed_state_names = ["data"]
148
247
 
248
+ if season_length <= 1 or not isinstance(season_length, int):
249
+ raise ValueError(
250
+ f"season_length must be an integer greater than 1, got {season_length}"
251
+ )
252
+ if duration <= 0 or not isinstance(duration, int):
253
+ raise ValueError(f"duration must be a positive integer, got {duration}")
149
254
  if name is None:
150
- name = f"Seasonal[s={season_length}]"
255
+ name = f"Seasonal[s={season_length}, d={duration}]"
151
256
  if state_names is None:
152
- state_names = [f"{name}_{i}" for i in range(season_length)]
257
+ if duration > 1:
258
+ state_names = [
259
+ f"{name}_{i}_{j}" for i in range(season_length) for j in range(duration)
260
+ ]
261
+ else:
262
+ state_names = [f"{name}_{i}" for i in range(season_length)]
153
263
  else:
154
- if len(state_names) != season_length:
264
+ if len(state_names) != season_length * duration:
155
265
  raise ValueError(
156
- f"state_names must be a list of length season_length, got {len(state_names)}"
266
+ f"state_names must be a list of length season_length*duration, got {len(state_names)}"
157
267
  )
158
268
  state_names = state_names.copy()
159
269
 
270
+ self.share_states = share_states
160
271
  self.innovations = innovations
272
+ self.duration = duration
161
273
  self.remove_first_state = remove_first_state
274
+ self.season_length = season_length
162
275
 
163
276
  if self.remove_first_state:
164
277
  # In traditional models, the first state isn't identified, so we can help out the user by automatically
165
278
  # discarding it.
166
279
  # TODO: Can this be stashed and reconstructed automatically somehow?
167
- state_names.pop(0)
280
+ state_names = state_names[duration:]
168
281
 
169
282
  self.provided_state_names = state_names
170
283
 
171
- k_states = season_length - int(self.remove_first_state)
284
+ k_states = (season_length - int(self.remove_first_state)) * duration
172
285
  k_endog = len(observed_state_names)
173
286
  k_posdef = int(innovations)
174
287
 
175
288
  super().__init__(
176
289
  name=name,
177
290
  k_endog=k_endog,
178
- k_states=k_states * k_endog,
179
- k_posdef=k_posdef * k_endog,
291
+ k_states=k_states if share_states else k_states * k_endog,
292
+ k_posdef=k_posdef if share_states else k_posdef * k_endog,
180
293
  observed_state_names=observed_state_names,
181
294
  measurement_error=False,
182
295
  combine_hidden_states=True,
183
- obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
296
+ obs_state_idxs=np.tile(
297
+ np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
298
+ ),
299
+ share_states=share_states,
184
300
  )
185
301
 
186
302
  def populate_component_properties(self):
187
- k_states = self.k_states // self.k_endog
188
303
  k_endog = self.k_endog
304
+ k_endog_effective = 1 if self.share_states else k_endog
305
+
306
+ k_states = self.k_states // k_endog_effective
189
307
 
190
- self.state_names = [
191
- f"{state_name}[{endog_name}]"
192
- for endog_name in self.observed_state_names
193
- for state_name in self.provided_state_names
194
- ]
195
- self.param_names = [f"coefs_{self.name}"]
308
+ if self.share_states:
309
+ self.state_names = [
310
+ f"{state_name}[{self.name}_shared]" for state_name in self.provided_state_names
311
+ ]
312
+ else:
313
+ self.state_names = [
314
+ f"{state_name}[{endog_name}]"
315
+ for endog_name in self.observed_state_names
316
+ for state_name in self.provided_state_names
317
+ ]
318
+
319
+ self.param_names = [f"params_{self.name}"]
196
320
 
197
321
  self.param_info = {
198
- f"coefs_{self.name}": {
322
+ f"params_{self.name}": {
199
323
  "shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
200
324
  "constraints": None,
201
325
  "dims": (f"state_{self.name}",)
202
- if k_endog == 1
326
+ if k_endog_effective == 1
203
327
  else (f"endog_{self.name}", f"state_{self.name}"),
204
328
  }
205
329
  }
206
330
 
207
331
  self.param_dims = {
208
- f"coefs_{self.name}": (f"state_{self.name}",)
209
- if k_endog == 1
332
+ f"params_{self.name}": (f"state_{self.name}",)
333
+ if k_endog_effective == 1
210
334
  else (f"endog_{self.name}", f"state_{self.name}")
211
335
  }
212
336
 
213
337
  self.coords = (
214
338
  {f"state_{self.name}": self.provided_state_names}
215
- if k_endog == 1
339
+ if k_endog_effective == 1
216
340
  else {
217
341
  f"endog_{self.name}": self.observed_state_names,
218
342
  f"state_{self.name}": self.provided_state_names,
@@ -222,45 +346,82 @@ class TimeSeasonality(Component):
222
346
  if self.innovations:
223
347
  self.param_names += [f"sigma_{self.name}"]
224
348
  self.param_info[f"sigma_{self.name}"] = {
225
- "shape": (),
349
+ "shape": () if k_endog_effective == 1 else (k_endog,),
226
350
  "constraints": "Positive",
227
- "dims": None,
351
+ "dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
228
352
  }
229
- self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
353
+ if self.share_states:
354
+ self.shock_names = [f"{self.name}[shared]"]
355
+ else:
356
+ self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
357
+
358
+ if k_endog > 1:
359
+ self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
230
360
 
231
361
  def make_symbolic_graph(self) -> None:
232
- k_states = self.k_states // self.k_endog
233
- k_posdef = self.k_posdef // self.k_endog
234
362
  k_endog = self.k_endog
363
+ k_endog_effective = 1 if self.share_states else k_endog
364
+ k_states = self.k_states // k_endog_effective
365
+ duration = self.duration
366
+
367
+ k_unique_states = k_states // duration
368
+ k_posdef = self.k_posdef // k_endog_effective
235
369
 
236
370
  if self.remove_first_state:
237
371
  # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
238
372
  # all previous states.
239
- T = np.eye(k_states, k=-1)
240
- T[0, :] = -1
373
+ zero_d = pt.zeros((self.duration, self.duration))
374
+ id_d = pt.eye(self.duration)
375
+
376
+ row_blocks = []
377
+
378
+ # First row: all -1_d blocks
379
+ first_row = [-id_d for _ in range(self.season_length - 1)]
380
+ row_blocks.append(pt.concatenate(first_row, axis=1))
381
+
382
+ # Rows 2 to season_length-1: shifted identity blocks
383
+ for i in range(self.season_length - 2):
384
+ row = []
385
+ for j in range(self.season_length - 1):
386
+ if j == i:
387
+ row.append(id_d)
388
+ else:
389
+ row.append(zero_d)
390
+ row_blocks.append(pt.concatenate(row, axis=1))
391
+
392
+ # Stack blocks
393
+ T = pt.concatenate(row_blocks, axis=0)
241
394
  else:
242
395
  # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
243
396
  # circulant matrix that cycles between the states.
244
- T = np.eye(k_states, k=1)
245
- T[-1, 0] = 1
397
+ T = pt.eye(k_states, k=1)
398
+ T = pt.set_subtensor(T[-1, 0], 1)
246
399
 
247
- self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
400
+ self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
248
401
 
249
402
  Z = pt.zeros((1, k_states))[0, 0].set(1)
250
- self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
403
+ self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
251
404
 
252
405
  initial_states = self.make_and_register_variable(
253
- f"coefs_{self.name}", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
406
+ f"params_{self.name}",
407
+ shape=(k_unique_states,)
408
+ if k_endog_effective == 1
409
+ else (k_endog_effective, k_unique_states),
254
410
  )
255
- self.ssm["initial_state", :] = initial_states.ravel()
411
+ if k_endog_effective == 1:
412
+ self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
413
+ else:
414
+ self.ssm["initial_state", :] = pt.extra_ops.repeat(
415
+ initial_states, duration, axis=1
416
+ ).ravel()
256
417
 
257
418
  if self.innovations:
258
419
  R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
259
- self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
420
+ self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
260
421
  season_sigma = self.make_and_register_variable(
261
- f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
422
+ f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
262
423
  )
263
- cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
424
+ cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog_effective))
264
425
  self.ssm[cov_idx] = season_sigma**2
265
426
 
266
427
 
@@ -289,6 +450,11 @@ class FrequencySeasonality(Component):
289
450
  observed_state_names: list[str] | None, default None
290
451
  List of strings for observed state labels. If None, defaults to ["data"].
291
452
 
453
+ share_states: bool, default False
454
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
455
+ states, which are observed by all observed states. If False, each observed state has its own set of
456
+ latent states. This argument has no effect if `k_endog` is 1.
457
+
292
458
  Notes
293
459
  -----
294
460
  A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -320,15 +486,17 @@ class FrequencySeasonality(Component):
320
486
 
321
487
  def __init__(
322
488
  self,
323
- season_length,
324
- n=None,
325
- name=None,
326
- innovations=True,
489
+ season_length: int,
490
+ n: int | None = None,
491
+ name: str | None = None,
492
+ innovations: bool = True,
327
493
  observed_state_names: list[str] | None = None,
494
+ share_states: bool = False,
328
495
  ):
329
496
  if observed_state_names is None:
330
497
  observed_state_names = ["data"]
331
498
 
499
+ self.share_states = share_states
332
500
  k_endog = len(observed_state_names)
333
501
 
334
502
  if n is None:
@@ -344,18 +512,21 @@ class FrequencySeasonality(Component):
344
512
  # If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
345
513
  # get a parameter assigned to it and should just be fixed to zero.
346
514
  # Test this way (rather than n == s // 2) to catch cases when n is non-integer.
347
- self.last_state_not_identified = self.season_length / self.n == 2.0
515
+ self.last_state_not_identified = (self.season_length / self.n) == 2.0
348
516
  self.n_coefs = k_states - int(self.last_state_not_identified)
349
517
 
350
518
  obs_state_idx = np.zeros(k_states)
351
519
  obs_state_idx[slice(0, k_states, 2)] = 1
352
- obs_state_idx = np.tile(obs_state_idx, k_endog)
520
+ obs_state_idx = np.tile(obs_state_idx, 1 if share_states else k_endog)
353
521
 
354
522
  super().__init__(
355
523
  name=name,
356
524
  k_endog=k_endog,
357
- k_states=k_states * k_endog,
358
- k_posdef=k_states * int(self.innovations) * k_endog,
525
+ k_states=k_states if share_states else k_states * k_endog,
526
+ k_posdef=k_states * int(self.innovations)
527
+ if share_states
528
+ else k_states * int(self.innovations) * k_endog,
529
+ share_states=share_states,
359
530
  observed_state_names=observed_state_names,
360
531
  measurement_error=False,
361
532
  combine_hidden_states=True,
@@ -364,22 +535,24 @@ class FrequencySeasonality(Component):
364
535
 
365
536
  def make_symbolic_graph(self) -> None:
366
537
  k_endog = self.k_endog
367
- k_states = self.k_states // k_endog
368
- k_posdef = self.k_posdef // k_endog
538
+ k_endog_effective = 1 if self.share_states else k_endog
539
+
540
+ k_states = self.k_states // k_endog_effective
541
+ k_posdef = self.k_posdef // k_endog_effective
369
542
  n_coefs = self.n_coefs
370
543
 
371
544
  Z = pt.zeros((1, k_states))[0, slice(0, k_states, 2)].set(1.0)
372
545
 
373
- self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
546
+ self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
374
547
 
375
548
  init_state = self.make_and_register_variable(
376
- f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
549
+ f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
377
550
  )
378
551
 
379
552
  init_state_idx = np.concatenate(
380
553
  [
381
554
  np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
382
- for i in range(k_endog)
555
+ for i in range(k_endog_effective)
383
556
  ],
384
557
  axis=0,
385
558
  )
@@ -388,11 +561,11 @@ class FrequencySeasonality(Component):
388
561
 
389
562
  T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
390
563
  T = pt.linalg.block_diag(*T_mats)
391
- self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
564
+ self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
392
565
 
393
566
  if self.innovations:
394
567
  sigma_season = self.make_and_register_variable(
395
- f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
568
+ f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
396
569
  )
397
570
  self.ssm["selection", :, :] = pt.eye(self.k_states)
398
571
  self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * pt.repeat(
@@ -401,45 +574,55 @@ class FrequencySeasonality(Component):
401
574
 
402
575
  def populate_component_properties(self):
403
576
  k_endog = self.k_endog
577
+ k_endog_effective = 1 if self.share_states else k_endog
404
578
  n_coefs = self.n_coefs
405
- k_states = self.k_states // k_endog
406
579
 
407
- self.state_names = [
408
- f"{f}_{self.name}_{i}[{obs_state_name}]"
409
- for obs_state_name in self.observed_state_names
410
- for i in range(self.n)
411
- for f in ["Cos", "Sin"]
412
- ]
413
- self.param_names = [f"{self.name}"]
580
+ base_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]]
414
581
 
415
- self.param_dims = {self.name: (f"state_{self.name}",)}
582
+ if self.share_states:
583
+ self.state_names = [f"{name}[shared]" for name in base_names]
584
+ else:
585
+ self.state_names = [
586
+ f"{name}[{obs_state_name}]"
587
+ for obs_state_name in self.observed_state_names
588
+ for name in base_names
589
+ ]
590
+
591
+ # Trim state names if the model is saturated
592
+ param_state_names = base_names[:n_coefs]
593
+
594
+ self.param_names = [f"params_{self.name}"]
595
+ self.param_dims = {
596
+ f"params_{self.name}": (f"state_{self.name}",)
597
+ if k_endog_effective == 1
598
+ else (f"endog_{self.name}", f"state_{self.name}")
599
+ }
416
600
  self.param_info = {
417
- f"{self.name}": {
418
- "shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
601
+ f"params_{self.name}": {
602
+ "shape": (n_coefs,) if k_endog_effective == 1 else (k_endog_effective, n_coefs),
419
603
  "constraints": None,
420
604
  "dims": (f"state_{self.name}",)
421
- if k_endog == 1
605
+ if k_endog_effective == 1
422
606
  else (f"endog_{self.name}", f"state_{self.name}"),
423
607
  }
424
608
  }
425
609
 
426
- # Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
427
- # That's why the self.states is just a simple loop over everything. But when saturated, one of those states
428
- # doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
429
- init_state_idx = np.concatenate(
430
- [
431
- np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
432
- for i in range(k_endog)
433
- ],
434
- axis=0,
610
+ self.coords = (
611
+ {f"state_{self.name}": param_state_names}
612
+ if k_endog == 1
613
+ else {
614
+ f"endog_{self.name}": self.observed_state_names,
615
+ f"state_{self.name}": param_state_names,
616
+ }
435
617
  )
436
- self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]}
437
618
 
438
619
  if self.innovations:
439
- self.shock_names = self.state_names.copy()
440
620
  self.param_names += [f"sigma_{self.name}"]
621
+ self.shock_names = self.state_names.copy()
441
622
  self.param_info[f"sigma_{self.name}"] = {
442
- "shape": () if k_endog == 1 else (k_endog, n_coefs),
623
+ "shape": () if k_endog_effective == 1 else (k_endog_effective, n_coefs),
443
624
  "constraints": "Positive",
444
- "dims": None if k_endog == 1 else (f"endog_{self.name}",),
625
+ "dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
445
626
  }
627
+ if k_endog_effective > 1:
628
+ self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)