pymc-extras 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. pymc_extras/distributions/__init__.py +5 -5
  2. pymc_extras/distributions/histogram_utils.py +1 -1
  3. pymc_extras/inference/__init__.py +1 -1
  4. pymc_extras/printing.py +1 -1
  5. pymc_extras/statespace/__init__.py +4 -4
  6. pymc_extras/statespace/core/__init__.py +1 -1
  7. pymc_extras/statespace/core/representation.py +8 -8
  8. pymc_extras/statespace/core/statespace.py +94 -23
  9. pymc_extras/statespace/filters/__init__.py +3 -3
  10. pymc_extras/statespace/filters/kalman_filter.py +16 -11
  11. pymc_extras/statespace/models/SARIMAX.py +138 -74
  12. pymc_extras/statespace/models/VARMAX.py +248 -57
  13. pymc_extras/statespace/models/__init__.py +2 -2
  14. pymc_extras/statespace/models/structural/__init__.py +4 -4
  15. pymc_extras/statespace/models/structural/components/autoregressive.py +49 -24
  16. pymc_extras/statespace/models/structural/components/cycle.py +48 -28
  17. pymc_extras/statespace/models/structural/components/level_trend.py +61 -29
  18. pymc_extras/statespace/models/structural/components/measurement_error.py +22 -5
  19. pymc_extras/statespace/models/structural/components/regression.py +47 -18
  20. pymc_extras/statespace/models/structural/components/seasonality.py +278 -95
  21. pymc_extras/statespace/models/structural/core.py +27 -8
  22. pymc_extras/statespace/utils/constants.py +17 -14
  23. pymc_extras/statespace/utils/data_tools.py +1 -1
  24. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/METADATA +1 -1
  25. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/RECORD +27 -27
  26. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/WHEEL +0 -0
  27. {pymc_extras-0.4.0.dist-info → pymc_extras-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -43,6 +43,11 @@ class CycleComponent(Component):
43
43
  Names of the observed state variables. For univariate time series, defaults to ``["data"]``.
44
44
  For multivariate time series, specify a list of names for each endogenous variable.
45
45
 
46
+ share_states: bool, default False
47
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
48
+ states, which are observed by all observed states. If False, each observed state has its own set of
49
+ latent states. This argument has no effect if `k_endog` is 1.
50
+
46
51
  Notes
47
52
  -----
48
53
  The cycle component is very similar in implementation to the frequency domain seasonal component, expect that it
@@ -155,6 +160,7 @@ class CycleComponent(Component):
155
160
  dampen: bool = False,
156
161
  innovations: bool = True,
157
162
  observed_state_names: list[str] | None = None,
163
+ share_states: bool = False,
158
164
  ):
159
165
  if observed_state_names is None:
160
166
  observed_state_names = ["data"]
@@ -167,6 +173,7 @@ class CycleComponent(Component):
167
173
  cycle = int(cycle_length) if cycle_length is not None else "Estimate"
168
174
  name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]"
169
175
 
176
+ self.share_states = share_states
170
177
  self.estimate_cycle_length = estimate_cycle_length
171
178
  self.cycle_length = cycle_length
172
179
  self.innovations = innovations
@@ -175,8 +182,8 @@ class CycleComponent(Component):
175
182
 
176
183
  k_endog = len(observed_state_names)
177
184
 
178
- k_states = 2 * k_endog
179
- k_posdef = 2 * k_endog
185
+ k_states = 2 if share_states else 2 * k_endog
186
+ k_posdef = 2 if share_states else 2 * k_endog
180
187
 
181
188
  obs_state_idx = np.zeros(k_states)
182
189
  obs_state_idx[slice(0, k_states, 2)] = 1
@@ -190,21 +197,26 @@ class CycleComponent(Component):
190
197
  combine_hidden_states=True,
191
198
  obs_state_idxs=obs_state_idx,
192
199
  observed_state_names=observed_state_names,
200
+ share_states=share_states,
193
201
  )
194
202
 
195
203
  def make_symbolic_graph(self) -> None:
204
+ k_endog = self.k_endog
205
+ k_endog_effective = 1 if self.share_states else k_endog
206
+
196
207
  Z = np.array([1.0, 0.0]).reshape((1, -1))
197
- design_matrix = block_diag(*[Z for _ in range(self.k_endog)])
208
+ design_matrix = block_diag(*[Z for _ in range(k_endog_effective)])
198
209
  self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
199
210
 
200
211
  # selection matrix R defines structure of innovations (always identity for cycle components)
201
212
  # when innovations=False, state cov Q=0, hence R @ Q @ R.T = 0
202
213
  R = np.eye(2) # 2x2 identity for each cycle component
203
- selection_matrix = block_diag(*[R for _ in range(self.k_endog)])
214
+ selection_matrix = block_diag(*[R for _ in range(k_endog_effective)])
204
215
  self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
205
216
 
206
217
  init_state = self.make_and_register_variable(
207
- f"{self.name}", shape=(self.k_endog, 2) if self.k_endog > 1 else (self.k_states,)
218
+ f"params_{self.name}",
219
+ shape=(k_endog_effective, 2) if k_endog_effective > 1 else (self.k_states,),
208
220
  )
209
221
  self.ssm["initial_state", :] = init_state.ravel()
210
222
 
@@ -219,19 +231,19 @@ class CycleComponent(Component):
219
231
  rho = 1
220
232
 
221
233
  T = rho * _frequency_transition_block(lamb, j=1)
222
- transition = block_diag(*[T for _ in range(self.k_endog)])
234
+ transition = block_diag(*[T for _ in range(k_endog_effective)])
223
235
  self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
224
236
 
225
237
  if self.innovations:
226
- if self.k_endog == 1:
238
+ if k_endog_effective == 1:
227
239
  sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
228
240
  self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
229
241
  else:
230
242
  sigma_cycle = self.make_and_register_variable(
231
- f"sigma_{self.name}", shape=(self.k_endog,)
243
+ f"sigma_{self.name}", shape=(k_endog_effective,)
232
244
  )
233
245
  state_cov = block_diag(
234
- *[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
246
+ *[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(k_endog_effective)]
235
247
  )
236
248
  self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
237
249
  else:
@@ -239,33 +251,41 @@ class CycleComponent(Component):
239
251
  self.ssm["state_cov", :, :] = pt.zeros((self.k_posdef, self.k_posdef))
240
252
 
241
253
  def populate_component_properties(self):
242
- self.state_names = [
243
- f"{f}_{self.name}[{var_name}]" if self.k_endog > 1 else f"{f}_{self.name}"
244
- for var_name in self.observed_state_names
245
- for f in ["Cos", "Sin"]
246
- ]
254
+ k_endog = self.k_endog
255
+ k_endog_effective = 1 if self.share_states else k_endog
256
+
257
+ base_names = [f"{f}_{self.name}" for f in ["Cos", "Sin"]]
258
+
259
+ if self.share_states:
260
+ self.state_names = [f"{name}[shared]" for name in base_names]
261
+ else:
262
+ self.state_names = [
263
+ f"{name}[{var_name}]" if k_endog_effective > 1 else name
264
+ for var_name in self.observed_state_names
265
+ for name in base_names
266
+ ]
247
267
 
248
- self.param_names = [f"{self.name}"]
268
+ self.param_names = [f"params_{self.name}"]
249
269
 
250
- if self.k_endog == 1:
251
- self.param_dims = {self.name: (f"state_{self.name}",)}
252
- self.coords = {f"state_{self.name}": self.state_names}
270
+ if k_endog_effective == 1:
271
+ self.param_dims = {f"params_{self.name}": (f"state_{self.name}",)}
272
+ self.coords = {f"state_{self.name}": base_names}
253
273
  self.param_info = {
254
- f"{self.name}": {
274
+ f"params_{self.name}": {
255
275
  "shape": (2,),
256
276
  "constraints": None,
257
277
  "dims": (f"state_{self.name}",),
258
278
  }
259
279
  }
260
280
  else:
261
- self.param_dims = {self.name: (f"endog_{self.name}", f"state_{self.name}")}
281
+ self.param_dims = {f"params_{self.name}": (f"endog_{self.name}", f"state_{self.name}")}
262
282
  self.coords = {
263
283
  f"state_{self.name}": [f"Cos_{self.name}", f"Sin_{self.name}"],
264
284
  f"endog_{self.name}": self.observed_state_names,
265
285
  }
266
286
  self.param_info = {
267
- f"{self.name}": {
268
- "shape": (self.k_endog, 2),
287
+ f"params_{self.name}": {
288
+ "shape": (k_endog_effective, 2),
269
289
  "constraints": None,
270
290
  "dims": (f"endog_{self.name}", f"state_{self.name}"),
271
291
  }
@@ -274,22 +294,22 @@ class CycleComponent(Component):
274
294
  if self.estimate_cycle_length:
275
295
  self.param_names += [f"length_{self.name}"]
276
296
  self.param_info[f"length_{self.name}"] = {
277
- "shape": () if self.k_endog == 1 else (self.k_endog,),
297
+ "shape": () if k_endog_effective == 1 else (k_endog_effective,),
278
298
  "constraints": "Positive, non-zero",
279
- "dims": None if self.k_endog == 1 else f"endog_{self.name}",
299
+ "dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
280
300
  }
281
301
 
282
302
  if self.dampen:
283
303
  self.param_names += [f"dampening_factor_{self.name}"]
284
304
  self.param_info[f"dampening_factor_{self.name}"] = {
285
- "shape": () if self.k_endog == 1 else (self.k_endog,),
305
+ "shape": () if k_endog_effective == 1 else (k_endog_effective,),
286
306
  "constraints": "0 < x ≤ 1",
287
- "dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
307
+ "dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
288
308
  }
289
309
 
290
310
  if self.innovations:
291
311
  self.param_names += [f"sigma_{self.name}"]
292
- if self.k_endog == 1:
312
+ if k_endog_effective == 1:
293
313
  self.param_info[f"sigma_{self.name}"] = {
294
314
  "shape": (),
295
315
  "constraints": "Positive",
@@ -298,7 +318,7 @@ class CycleComponent(Component):
298
318
  else:
299
319
  self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
300
320
  self.param_info[f"sigma_{self.name}"] = {
301
- "shape": (self.k_endog,),
321
+ "shape": (k_endog_effective,),
302
322
  "constraints": "Positive",
303
323
  "dims": (f"endog_{self.name}",),
304
324
  }
@@ -13,13 +13,11 @@ class LevelTrendComponent(Component):
13
13
  Parameters
14
14
  ----------
15
15
  order : int
16
-
17
16
  Number of time derivatives of the trend to include in the model. For example, when order=3, the trend will
18
17
  be of the form ``y = a + b * t + c * t ** 2``, where the coefficients ``a, b, c`` come from the initial
19
18
  state values.
20
19
 
21
20
  innovations_order : int or sequence of int, optional
22
-
23
21
  The number of stochastic innovations to include in the model. By default, ``innovations_order = order``
24
22
 
25
23
  name : str, default "level_trend"
@@ -28,6 +26,11 @@ class LevelTrendComponent(Component):
28
26
  observed_state_names : list[str] | None, default None
29
27
  List of strings for observed state labels. If None, defaults to ["data"].
30
28
 
29
+ share_states: bool, default False
30
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
31
+ states, which are observed by all observed states. If False, each observed state has its own set of
32
+ latent states. This argument has no effect if `k_endog` is 1.
33
+
31
34
  Notes
32
35
  -----
33
36
  This class implements the level and trend components of the general structural time series model. In the most
@@ -120,7 +123,10 @@ class LevelTrendComponent(Component):
120
123
  innovations_order: int | list[int] | None = None,
121
124
  name: str = "level_trend",
122
125
  observed_state_names: list[str] | None = None,
126
+ share_states: bool = False,
123
127
  ):
128
+ self.share_states = share_states
129
+
124
130
  if innovations_order is None:
125
131
  innovations_order = order
126
132
 
@@ -156,37 +162,51 @@ class LevelTrendComponent(Component):
156
162
  super().__init__(
157
163
  name,
158
164
  k_endog=k_endog,
159
- k_states=k_states * k_endog,
160
- k_posdef=k_posdef * k_endog,
165
+ k_states=k_states * k_endog if not share_states else k_states,
166
+ k_posdef=k_posdef * k_endog if not share_states else k_posdef,
161
167
  observed_state_names=observed_state_names,
162
168
  measurement_error=False,
163
169
  combine_hidden_states=False,
164
- obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
170
+ obs_state_idxs=np.tile(
171
+ np.array([1.0] + [0.0] * (k_states - 1)), k_endog if not share_states else 1
172
+ ),
173
+ share_states=share_states,
165
174
  )
166
175
 
167
176
  def populate_component_properties(self):
168
177
  k_endog = self.k_endog
169
- k_states = self.k_states // k_endog
170
- k_posdef = self.k_posdef // k_endog
178
+ k_endog_effective = 1 if self.share_states else k_endog
179
+
180
+ k_states = self.k_states // k_endog_effective
181
+ k_posdef = self.k_posdef // k_endog_effective
171
182
 
172
183
  name_slice = POSITION_DERIVATIVE_NAMES[:k_states]
173
184
  self.param_names = [f"initial_{self.name}"]
174
185
  base_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
175
- self.state_names = [
176
- f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
177
- ]
186
+
187
+ if self.share_states:
188
+ self.state_names = [f"{name}[{self.name}_shared]" for name in base_names]
189
+ else:
190
+ self.state_names = [
191
+ f"{name}[{obs_name}]"
192
+ for obs_name in self.observed_state_names
193
+ for name in base_names
194
+ ]
195
+
178
196
  self.param_dims = {f"initial_{self.name}": (f"state_{self.name}",)}
179
197
  self.coords = {f"state_{self.name}": base_names}
180
198
 
181
199
  if k_endog > 1:
200
+ self.coords[f"endog_{self.name}"] = self.observed_state_names
201
+
202
+ if k_endog_effective > 1:
182
203
  self.param_dims[f"state_{self.name}"] = (
183
204
  f"endog_{self.name}",
184
205
  f"state_{self.name}",
185
206
  )
186
207
  self.param_dims = {f"initial_{self.name}": (f"endog_{self.name}", f"state_{self.name}")}
187
- self.coords[f"endog_{self.name}"] = self.observed_state_names
188
208
 
189
- shape = (k_endog, k_states) if k_endog > 1 else (k_states,)
209
+ shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,)
190
210
  self.param_info = {f"initial_{self.name}": {"shape": shape, "constraints": None}}
191
211
 
192
212
  if self.k_posdef > 0:
@@ -196,20 +216,23 @@ class LevelTrendComponent(Component):
196
216
  name for name, mask in zip(name_slice, self.innovations_order) if mask
197
217
  ]
198
218
 
199
- self.shock_names = [
200
- f"{name}[{obs_name}]"
201
- for obs_name in self.observed_state_names
202
- for name in base_shock_names
203
- ]
219
+ if self.share_states:
220
+ self.shock_names = [f"{name}[{self.name}_shared]" for name in base_shock_names]
221
+ else:
222
+ self.shock_names = [
223
+ f"{name}[{obs_name}]"
224
+ for obs_name in self.observed_state_names
225
+ for name in base_shock_names
226
+ ]
204
227
 
205
228
  self.param_dims[f"sigma_{self.name}"] = (
206
229
  (f"shock_{self.name}",)
207
- if k_endog == 1
230
+ if k_endog_effective == 1
208
231
  else (f"endog_{self.name}", f"shock_{self.name}")
209
232
  )
210
233
  self.coords[f"shock_{self.name}"] = base_shock_names
211
234
  self.param_info[f"sigma_{self.name}"] = {
212
- "shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
235
+ "shape": (k_posdef,) if k_endog_effective == 1 else (k_endog_effective, k_posdef),
213
236
  "constraints": "Positive",
214
237
  }
215
238
 
@@ -218,12 +241,14 @@ class LevelTrendComponent(Component):
218
241
 
219
242
  def make_symbolic_graph(self) -> None:
220
243
  k_endog = self.k_endog
221
- k_states = self.k_states // k_endog
222
- k_posdef = self.k_posdef // k_endog
244
+ k_endog_effective = 1 if self.share_states else k_endog
245
+
246
+ k_states = self.k_states // k_endog_effective
247
+ k_posdef = self.k_posdef // k_endog_effective
223
248
 
224
249
  initial_trend = self.make_and_register_variable(
225
250
  f"initial_{self.name}",
226
- shape=(k_states,) if k_endog == 1 else (k_endog, k_states),
251
+ shape=(k_states,) if k_endog_effective == 1 else (k_endog, k_states),
227
252
  )
228
253
  self.ssm["initial_state", :] = initial_trend.ravel()
229
254
 
@@ -231,27 +256,34 @@ class LevelTrendComponent(Component):
231
256
  T = pt.zeros((k_states, k_states))[triu_idx[0], triu_idx[1]].set(1)
232
257
 
233
258
  self.ssm["transition", :, :] = pt.specify_shape(
234
- pt.linalg.block_diag(*[T for _ in range(k_endog)]), (self.k_states, self.k_states)
259
+ pt.linalg.block_diag(*[T for _ in range(k_endog_effective)]),
260
+ (self.k_states, self.k_states),
235
261
  )
236
262
 
237
263
  R = np.eye(k_states)
238
264
  R = R[:, self.innovations_order]
239
265
 
240
266
  self.ssm["selection", :, :] = pt.specify_shape(
241
- pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
267
+ pt.linalg.block_diag(*[R for _ in range(k_endog_effective)]),
268
+ (self.k_states, self.k_posdef),
242
269
  )
243
270
 
244
271
  Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1))
245
272
 
246
- self.ssm["design", :, :] = pt.specify_shape(
247
- pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
248
- )
273
+ if self.share_states:
274
+ self.ssm["design", :, :] = pt.specify_shape(
275
+ pt.join(0, *[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
276
+ )
277
+ else:
278
+ self.ssm["design", :, :] = pt.specify_shape(
279
+ pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
280
+ )
249
281
 
250
282
  if k_posdef > 0:
251
283
  sigma_trend = self.make_and_register_variable(
252
284
  f"sigma_{self.name}",
253
- shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
285
+ shape=(k_posdef,) if k_endog_effective == 1 else (k_endog, k_posdef),
254
286
  )
255
- diag_idx = np.diag_indices(k_posdef * k_endog)
287
+ diag_idx = np.diag_indices(k_posdef * k_endog_effective)
256
288
  idx = np.s_["state_cov", diag_idx[0], diag_idx[1]]
257
289
  self.ssm[idx] = (sigma_trend**2).ravel()
@@ -17,6 +17,10 @@ class MeasurementError(Component):
17
17
  Name of the measurement error component. Default is "MeasurementError".
18
18
  observed_state_names : list[str] | None, optional
19
19
  Names of the observed variables. If None, defaults to ["data"].
20
+ share_states: bool, default False
21
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
22
+ states, which are observed by all observed states. If False, each observed state has its own set of
23
+ latent states. This argument has no effect if `k_endog` is 1.
20
24
 
21
25
  Notes
22
26
  -----
@@ -93,11 +97,16 @@ class MeasurementError(Component):
93
97
  """
94
98
 
95
99
  def __init__(
96
- self, name: str = "MeasurementError", observed_state_names: list[str] | None = None
100
+ self,
101
+ name: str = "MeasurementError",
102
+ observed_state_names: list[str] | None = None,
103
+ share_states: bool = False,
97
104
  ):
98
105
  if observed_state_names is None:
99
106
  observed_state_names = ["data"]
100
107
 
108
+ self.share_states = share_states
109
+
101
110
  k_endog = len(observed_state_names)
102
111
  k_states = 0
103
112
  k_posdef = 0
@@ -110,28 +119,36 @@ class MeasurementError(Component):
110
119
  measurement_error=True,
111
120
  combine_hidden_states=False,
112
121
  observed_state_names=observed_state_names,
122
+ share_states=share_states,
113
123
  )
114
124
 
115
125
  def populate_component_properties(self):
126
+ k_endog = self.k_endog
127
+ k_endog_effective = 1 if self.share_states else k_endog
128
+
116
129
  self.param_names = [f"sigma_{self.name}"]
117
130
  self.param_dims = {}
118
131
  self.coords = {}
119
132
 
120
- if self.k_endog > 1:
133
+ if k_endog_effective > 1:
121
134
  self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
122
135
  self.coords[f"endog_{self.name}"] = self.observed_state_names
123
136
 
124
137
  self.param_info = {
125
138
  f"sigma_{self.name}": {
126
- "shape": (self.k_endog,) if self.k_endog > 1 else (),
139
+ "shape": (k_endog_effective,) if k_endog_effective > 1 else (),
127
140
  "constraints": "Positive",
128
- "dims": (f"endog_{self.name}",) if self.k_endog > 1 else None,
141
+ "dims": (f"endog_{self.name}",) if k_endog_effective > 1 else None,
129
142
  }
130
143
  }
131
144
 
132
145
  def make_symbolic_graph(self) -> None:
133
- sigma_shape = () if self.k_endog == 1 else (self.k_endog,)
146
+ k_endog = self.k_endog
147
+ k_endog_effective = 1 if self.share_states else k_endog
148
+
149
+ sigma_shape = () if k_endog_effective == 1 else (k_endog_effective,)
134
150
  error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)
151
+
135
152
  diag_idx = np.diag_indices(self.k_endog)
136
153
  idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]
137
154
  self.ssm[idx] = error_sigma**2
@@ -31,6 +31,11 @@ class RegressionComponent(Component):
31
31
  Whether to include stochastic innovations in the regression coefficients,
32
32
  allowing them to vary over time. If True, coefficients follow a random walk.
33
33
 
34
+ share_states: bool, default False
35
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
36
+ states, which are observed by all observed states. If False, each observed state has its own set of
37
+ latent states.
38
+
34
39
  Notes
35
40
  -----
36
41
  This component implements regression with exogenous variables in a structural time series
@@ -107,7 +112,10 @@ class RegressionComponent(Component):
107
112
  state_names: list[str] | None = None,
108
113
  observed_state_names: list[str] | None = None,
109
114
  innovations=False,
115
+ share_states: bool = False,
110
116
  ):
117
+ self.share_states = share_states
118
+
111
119
  if observed_state_names is None:
112
120
  observed_state_names = ["data"]
113
121
 
@@ -121,9 +129,10 @@ class RegressionComponent(Component):
121
129
  super().__init__(
122
130
  name=name,
123
131
  k_endog=k_endog,
124
- k_states=k_states * k_endog,
125
- k_posdef=k_posdef * k_endog,
132
+ k_states=k_states * k_endog if not share_states else k_states,
133
+ k_posdef=k_posdef * k_endog if not share_states else k_posdef,
126
134
  state_names=self.state_names,
135
+ share_states=share_states,
127
136
  observed_state_names=observed_state_names,
128
137
  measurement_error=False,
129
138
  combine_hidden_states=False,
@@ -153,10 +162,12 @@ class RegressionComponent(Component):
153
162
 
154
163
  def make_symbolic_graph(self) -> None:
155
164
  k_endog = self.k_endog
156
- k_states = self.k_states // k_endog
165
+ k_endog_effective = 1 if self.share_states else k_endog
166
+
167
+ k_states = self.k_states // k_endog_effective
157
168
 
158
169
  betas = self.make_and_register_variable(
159
- f"beta_{self.name}", shape=(k_endog, k_states) if k_endog > 1 else (k_states,)
170
+ f"beta_{self.name}", shape=(k_endog, k_states) if k_endog_effective > 1 else (k_states,)
160
171
  )
161
172
  regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states))
162
173
 
@@ -164,43 +175,61 @@ class RegressionComponent(Component):
164
175
  self.ssm["transition", :, :] = pt.eye(self.k_states)
165
176
  self.ssm["selection", :, :] = pt.eye(self.k_states)
166
177
 
167
- Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)])
168
- self.ssm["design"] = pt.specify_shape(
169
- Z, (None, k_endog, regression_data.type.shape[1] * k_endog)
170
- )
178
+ if self.share_states:
179
+ self.ssm["design"] = pt.specify_shape(
180
+ pt.join(1, *[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]),
181
+ (None, k_endog, self.k_states),
182
+ )
183
+ else:
184
+ Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)])
185
+ self.ssm["design"] = pt.specify_shape(
186
+ Z, (None, k_endog, regression_data.type.shape[1] * k_endog)
187
+ )
171
188
 
172
189
  if self.innovations:
173
190
  sigma_beta = self.make_and_register_variable(
174
- f"sigma_beta_{self.name}", (k_states,) if k_endog == 1 else (k_endog, k_states)
191
+ f"sigma_beta_{self.name}",
192
+ (k_states,) if k_endog_effective == 1 else (k_endog, k_states),
175
193
  )
176
194
  row_idx, col_idx = np.diag_indices(self.k_states)
177
195
  self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2
178
196
 
179
197
  def populate_component_properties(self) -> None:
180
198
  k_endog = self.k_endog
181
- k_states = self.k_states // k_endog
199
+ k_endog_effective = 1 if self.share_states else k_endog
200
+
201
+ k_states = self.k_states // k_endog_effective
182
202
 
183
- self.shock_names = self.state_names
203
+ if self.share_states:
204
+ self.shock_names = [f"{state_name}_shared" for state_name in self.state_names]
205
+ else:
206
+ self.shock_names = self.state_names
184
207
 
185
208
  self.param_names = [f"beta_{self.name}"]
186
209
  self.data_names = [f"data_{self.name}"]
187
210
  self.param_dims = {
188
211
  f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}")
189
- if k_endog > 1
212
+ if k_endog_effective > 1
190
213
  else (f"state_{self.name}",)
191
214
  }
192
215
 
193
216
  base_names = self.state_names
194
- self.state_names = [
195
- f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
196
- ]
217
+
218
+ if self.share_states:
219
+ self.state_names = [f"{name}[{self.name}_shared]" for name in base_names]
220
+ else:
221
+ self.state_names = [
222
+ f"{name}[{obs_name}]"
223
+ for obs_name in self.observed_state_names
224
+ for name in base_names
225
+ ]
197
226
 
198
227
  self.param_info = {
199
228
  f"beta_{self.name}": {
200
- "shape": (k_endog, k_states) if k_endog > 1 else (k_states,),
229
+ "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,),
201
230
  "constraints": None,
202
231
  "dims": (f"endog_{self.name}", f"state_{self.name}")
203
- if k_endog > 1
232
+ if k_endog_effective > 1
204
233
  else (f"state_{self.name}",),
205
234
  },
206
235
  }
@@ -223,6 +252,6 @@ class RegressionComponent(Component):
223
252
  "shape": (k_states,),
224
253
  "constraints": "Positive",
225
254
  "dims": (f"state_{self.name}",)
226
- if k_endog == 1
255
+ if k_endog_effective == 1
227
256
  else (f"endog_{self.name}", f"state_{self.name}"),
228
257
  }