pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. pymc_extras/distributions/timeseries.py +10 -10
  2. pymc_extras/inference/dadvi/dadvi.py +14 -83
  3. pymc_extras/inference/laplace_approx/laplace.py +187 -159
  4. pymc_extras/inference/pathfinder/pathfinder.py +12 -7
  5. pymc_extras/inference/smc/sampling.py +2 -2
  6. pymc_extras/model/marginal/distributions.py +4 -2
  7. pymc_extras/model/marginal/marginal_model.py +12 -2
  8. pymc_extras/prior.py +3 -3
  9. pymc_extras/statespace/core/properties.py +276 -0
  10. pymc_extras/statespace/core/statespace.py +182 -45
  11. pymc_extras/statespace/filters/distributions.py +19 -34
  12. pymc_extras/statespace/filters/kalman_filter.py +13 -12
  13. pymc_extras/statespace/filters/kalman_smoother.py +2 -2
  14. pymc_extras/statespace/models/DFM.py +179 -168
  15. pymc_extras/statespace/models/ETS.py +177 -151
  16. pymc_extras/statespace/models/SARIMAX.py +149 -152
  17. pymc_extras/statespace/models/VARMAX.py +134 -145
  18. pymc_extras/statespace/models/__init__.py +8 -1
  19. pymc_extras/statespace/models/structural/__init__.py +30 -8
  20. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  21. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  22. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  23. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  24. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  25. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  26. pymc_extras/statespace/models/structural/core.py +397 -286
  27. pymc_extras/statespace/models/utilities.py +5 -20
  28. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
  29. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
  30. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  31. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,20 @@
1
1
  from collections.abc import Sequence
2
- from typing import Any
3
2
 
4
3
  import numpy as np
5
4
  import pytensor.tensor as pt
6
5
 
7
6
  from pytensor.compile.mode import Mode
8
- from pytensor.tensor.slinalg import solve_discrete_lyapunov
9
-
7
+ from pytensor.tensor.linalg import solve_discrete_lyapunov
8
+
9
+ from pymc_extras.statespace.core.properties import (
10
+ Coord,
11
+ Data,
12
+ Parameter,
13
+ Shock,
14
+ State,
15
+ )
10
16
  from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
11
17
  from pymc_extras.statespace.models.utilities import (
12
- make_default_coords,
13
18
  make_harvey_state_names,
14
19
  make_SARIMA_transition_matrix,
15
20
  validate_names,
@@ -20,7 +25,6 @@ from pymc_extras.statespace.utils.constants import (
20
25
  AR_PARAM_DIM,
21
26
  EXOG_STATE_DIM,
22
27
  MA_PARAM_DIM,
23
- OBS_STATE_DIM,
24
28
  SARIMAX_STATE_STRUCTURES,
25
29
  SEASONAL_AR_PARAM_DIM,
26
30
  SEASONAL_MA_PARAM_DIM,
@@ -132,7 +136,7 @@ class BayesianSARIMAX(PyMCStateSpace):
132
136
  self,
133
137
  order: tuple[int, int, int],
134
138
  seasonal_order: tuple[int, int, int, int] | None = None,
135
- exog_state_names: list[str] | None = None,
139
+ exog_state_names: Sequence[str] | None = None,
136
140
  stationary_initialization: bool = True,
137
141
  filter_type: str = "standard",
138
142
  state_structure: str = "fast",
@@ -163,7 +167,7 @@ class BayesianSARIMAX(PyMCStateSpace):
163
167
  possible for the seasonal lags and the ARIMA lags to overlap, for example if P <= p. In this case, an error
164
168
  will be raised.
165
169
 
166
- exog_state_names : list[str], optional
170
+ exog_state_names : Sequence of str, optional
167
171
  Names of the exogenous state variables.
168
172
 
169
173
  stationary_initialization : bool, default True
@@ -213,7 +217,7 @@ class BayesianSARIMAX(PyMCStateSpace):
213
217
  ) # Not sure if this adds anything
214
218
  k_exog = len(exog_state_names) if exog_state_names is not None else 0
215
219
 
216
- self.exog_state_names = exog_state_names
220
+ self.exog_state_names = tuple(exog_state_names) if exog_state_names is not None else None
217
221
  self.k_exog = k_exog
218
222
 
219
223
  self.P, self.D, self.Q, self.S = seasonal_order
@@ -268,166 +272,159 @@ class BayesianSARIMAX(PyMCStateSpace):
268
272
  )
269
273
  self._needs_exog_data = self.k_exog > 0
270
274
 
271
- @property
272
- def param_names(self):
273
- names = [
274
- "x0",
275
- "P0",
276
- "ar_params",
277
- "ma_params",
278
- "seasonal_ar_params",
279
- "seasonal_ma_params",
280
- "beta_exog",
281
- "sigma_state",
282
- "sigma_obs",
283
- ]
284
- if self.stationary_initialization:
285
- names.remove("P0")
286
- names.remove("x0")
287
- if self.p == 0:
288
- names.remove("ar_params")
289
- if self.P == 0:
290
- names.remove("seasonal_ar_params")
291
- if self.q == 0:
292
- names.remove("ma_params")
293
- if self.Q == 0:
294
- names.remove("seasonal_ma_params")
295
- if self.k_exog == 0:
296
- names.remove("beta_exog")
297
- if not self.measurement_error:
298
- names.remove("sigma_obs")
299
-
300
- return names
301
-
302
- @property
303
- def data_info(self) -> dict[str, dict[str, Any]]:
304
- info = {
305
- "exogenous_data": {
306
- "dims": (TIME_DIM, EXOG_STATE_DIM),
307
- "shape": (None, self.k_exog),
308
- }
309
- }
310
-
311
- return {name: info[name] for name in self.data_names}
312
-
313
- @property
314
- def param_info(self) -> dict[str, dict[str, Any]]:
315
- info = {
316
- "x0": {
317
- "shape": (self.k_states,),
318
- "constraints": None,
319
- },
320
- "P0": {
321
- "shape": (self.k_states, self.k_states),
322
- "constraints": "Positive Semi-definite",
323
- },
324
- "sigma_obs": {
325
- "shape": () if self.k_endog == 1 else (self.k_endog,),
326
- "constraints": "Positive",
327
- },
328
- "sigma_state": {
329
- "shape": () if self.k_posdef == 1 else (self.k_posdef,),
330
- "constraints": "Positive",
331
- },
332
- "ar_params": {
333
- "shape": (self.p,),
334
- "constraints": "None",
335
- },
336
- "ma_params": {
337
- "shape": (self.q,),
338
- "constraints": "None",
339
- },
340
- "seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
341
- "seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
342
- "beta_exog": {"shape": (self.k_exog,), "constraints": "None"},
343
- }
344
-
345
- for name in self.param_names:
346
- info[name]["dims"] = self.param_dims[name]
347
-
348
- return {name: info[name] for name in self.param_names}
349
-
350
- @property
351
- def state_names(self):
352
- if self.state_structure == "fast":
353
- p, d, q = self.p, self.d, self.q
354
- P, D, Q, S = self.P, self.D, self.Q, self.S
355
- states = make_harvey_state_names(p, d, q, P, D, Q, S)
275
+ def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
276
+ k_states = self.k_states
277
+ parameters = []
278
+
279
+ if not self.stationary_initialization:
280
+ parameters.append(
281
+ Parameter(
282
+ name="x0",
283
+ shape=(k_states,),
284
+ dims=(ALL_STATE_DIM,),
285
+ constraints=None,
286
+ )
287
+ )
288
+ parameters.append(
289
+ Parameter(
290
+ name="P0",
291
+ shape=(k_states, k_states),
292
+ dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM),
293
+ constraints="Positive Semi-definite",
294
+ )
295
+ )
296
+
297
+ if self.p > 0:
298
+ parameters.append(
299
+ Parameter(
300
+ name="ar_params",
301
+ shape=(self.p,),
302
+ dims=(AR_PARAM_DIM,),
303
+ constraints=None,
304
+ )
305
+ )
306
+
307
+ if self.q > 0:
308
+ parameters.append(
309
+ Parameter(
310
+ name="ma_params",
311
+ shape=(self.q,),
312
+ dims=(MA_PARAM_DIM,),
313
+ constraints=None,
314
+ )
315
+ )
316
+
317
+ if self.P > 0:
318
+ parameters.append(
319
+ Parameter(
320
+ name="seasonal_ar_params",
321
+ shape=(self.P,),
322
+ dims=(SEASONAL_AR_PARAM_DIM,),
323
+ constraints=None,
324
+ )
325
+ )
326
+
327
+ if self.Q > 0:
328
+ parameters.append(
329
+ Parameter(
330
+ name="seasonal_ma_params",
331
+ shape=(self.Q,),
332
+ dims=(SEASONAL_MA_PARAM_DIM,),
333
+ constraints=None,
334
+ )
335
+ )
356
336
 
337
+ if self.k_exog > 0:
338
+ parameters.append(
339
+ Parameter(
340
+ name="beta_exog",
341
+ shape=(self.k_exog,),
342
+ dims=(EXOG_STATE_DIM,),
343
+ constraints=None,
344
+ )
345
+ )
346
+
347
+ parameters.append(
348
+ Parameter(
349
+ name="sigma_state",
350
+ shape=(),
351
+ dims=None,
352
+ constraints="Positive",
353
+ )
354
+ )
355
+
356
+ if self.measurement_error:
357
+ parameters.append(
358
+ Parameter(
359
+ name="sigma_obs",
360
+ shape=(),
361
+ dims=None,
362
+ constraints="Positive",
363
+ )
364
+ )
365
+
366
+ return tuple(parameters)
367
+
368
+ def set_states(self) -> State | tuple[State, ...] | None:
369
+ if self.state_structure == "fast":
370
+ state_names = make_harvey_state_names(
371
+ self.p, self.d, self.q, self.P, self.D, self.Q, self.S
372
+ )
357
373
  elif self.state_structure == "interpretable":
358
- states = ["data"]
374
+ state_names = ["data"]
359
375
  if self.p > 0:
360
- states += [f"L{i + 1}.data" for i in range(self._p_max - 1)]
361
- states += ["innovations"]
376
+ state_names += [f"L{i + 1}.data" for i in range(self._p_max - 1)]
377
+ state_names += ["innovations"]
362
378
  if self.q > 0:
363
- states += [f"L{i + 1}.innovations" for i in range(self._q_max - 1)]
379
+ state_names += [f"L{i + 1}.innovations" for i in range(self._q_max - 1)]
364
380
  else:
365
381
  raise NotImplementedError()
366
382
 
367
- return states
383
+ hidden_states = [State(name=name, observed=False) for name in state_names]
384
+
385
+ # The first state is the observed state
386
+ observed_state = State(name=state_names[0], observed=True)
387
+
388
+ return *hidden_states, observed_state
389
+
390
+ def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
391
+ return Shock(name="innovation")
368
392
 
369
- @property
370
- def data_names(self) -> list[str]:
393
+ def set_data_info(self) -> tuple[Data, ...] | None:
371
394
  if self.k_exog > 0:
372
- return ["exogenous_data"]
373
- return []
374
-
375
- @property
376
- def observed_states(self):
377
- return [self.state_names[0]]
378
-
379
- @property
380
- def shock_names(self):
381
- return ["innovation"]
382
-
383
- @property
384
- def param_dims(self):
385
- coord_map = {
386
- "x0": (ALL_STATE_DIM,),
387
- "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
388
- "sigma_obs": (OBS_STATE_DIM,),
389
- "sigma_state": (OBS_STATE_DIM,),
390
- "ar_params": (AR_PARAM_DIM,),
391
- "ma_params": (MA_PARAM_DIM,),
392
- "seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
393
- "seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
394
- "beta_exog": (EXOG_STATE_DIM,),
395
- }
396
- if self.k_endog == 1:
397
- coord_map["sigma_state"] = None
398
- coord_map["sigma_obs"] = None
399
- if not self.measurement_error:
400
- del coord_map["sigma_obs"]
401
- if self.p == 0:
402
- del coord_map["ar_params"]
403
- if self.q == 0:
404
- del coord_map["ma_params"]
405
- if self.P == 0:
406
- del coord_map["seasonal_ar_params"]
407
- if self.Q == 0:
408
- del coord_map["seasonal_ma_params"]
409
- if self.k_exog == 0:
410
- del coord_map["beta_exog"]
411
- if self.stationary_initialization:
412
- del coord_map["P0"]
413
- del coord_map["x0"]
395
+ return (
396
+ Data(
397
+ name="exogenous_data",
398
+ shape=(None, self.k_exog),
399
+ dims=(TIME_DIM, EXOG_STATE_DIM),
400
+ is_exogenous=True,
401
+ ),
402
+ )
403
+ return None
414
404
 
415
- return coord_map
405
+ def set_coords(self) -> Coord | tuple[Coord, ...] | None:
406
+ coords = list(self.default_coords())
416
407
 
417
- @property
418
- def coords(self) -> dict[str, Sequence]:
419
- coords = make_default_coords(self)
420
408
  if self.p > 0:
421
- coords.update({AR_PARAM_DIM: list(range(1, self.p + 1))})
409
+ coords.append(Coord(dimension=AR_PARAM_DIM, labels=tuple(range(1, self.p + 1))))
410
+
422
411
  if self.q > 0:
423
- coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
412
+ coords.append(Coord(dimension=MA_PARAM_DIM, labels=tuple(range(1, self.q + 1))))
413
+
424
414
  if self.P > 0:
425
- coords.update({SEASONAL_AR_PARAM_DIM: list(range(1, self.P + 1))})
415
+ coords.append(
416
+ Coord(dimension=SEASONAL_AR_PARAM_DIM, labels=tuple(range(1, self.P + 1)))
417
+ )
418
+
426
419
  if self.Q > 0:
427
- coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})
420
+ coords.append(
421
+ Coord(dimension=SEASONAL_MA_PARAM_DIM, labels=tuple(range(1, self.Q + 1)))
422
+ )
423
+
428
424
  if self.k_exog > 0:
429
- coords.update({EXOG_STATE_DIM: self.exog_state_names})
430
- return coords
425
+ coords.append(Coord(dimension=EXOG_STATE_DIM, labels=tuple(self.exog_state_names)))
426
+
427
+ return tuple(coords)
431
428
 
432
429
  def _stationary_initialization(self):
433
430
  # Solve for matrix quadratic for P0