pymc-extras 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. pymc_extras/inference/laplace_approx/laplace.py +2 -2
  2. pymc_extras/inference/pathfinder/pathfinder.py +1 -1
  3. pymc_extras/prior.py +3 -3
  4. pymc_extras/statespace/core/properties.py +276 -0
  5. pymc_extras/statespace/core/statespace.py +180 -44
  6. pymc_extras/statespace/filters/distributions.py +12 -29
  7. pymc_extras/statespace/filters/kalman_filter.py +1 -1
  8. pymc_extras/statespace/models/DFM.py +179 -168
  9. pymc_extras/statespace/models/ETS.py +177 -151
  10. pymc_extras/statespace/models/SARIMAX.py +149 -152
  11. pymc_extras/statespace/models/VARMAX.py +134 -145
  12. pymc_extras/statespace/models/__init__.py +8 -1
  13. pymc_extras/statespace/models/structural/__init__.py +30 -8
  14. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  15. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  16. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  17. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  18. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  19. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  20. pymc_extras/statespace/models/structural/core.py +397 -286
  21. pymc_extras/statespace/models/utilities.py +5 -20
  22. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +3 -3
  23. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +25 -24
  24. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  25. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -19,6 +19,22 @@ from rich.box import SIMPLE_HEAD
19
19
  from rich.console import Console
20
20
  from rich.table import Table
21
21
 
22
+ from pymc_extras.statespace.core.properties import (
23
+ Coord,
24
+ CoordInfo,
25
+ Data,
26
+ DataInfo,
27
+ Parameter,
28
+ ParameterInfo,
29
+ Shock,
30
+ ShockInfo,
31
+ State,
32
+ StateInfo,
33
+ SymbolicData,
34
+ SymbolicDataInfo,
35
+ SymbolicVariable,
36
+ SymbolicVariableInfo,
37
+ )
22
38
  from pymc_extras.statespace.core.representation import PytensorRepresentation
23
39
  from pymc_extras.statespace.filters import (
24
40
  KalmanSmoother,
@@ -69,6 +85,22 @@ def _verify_group(group):
69
85
  raise ValueError(f'Argument "group" must be one of "prior" or "posterior", found {group}')
70
86
 
71
87
 
88
+ def _validate_property(props, property_name, expected_type):
89
+ if isinstance(props, expected_type) or props is None:
90
+ return
91
+ elif not isinstance(props, tuple | list):
92
+ raise TypeError(
93
+ f"The {property_name} property must be a {expected_type.__name__} or a "
94
+ f"list/tuple of {expected_type.__name__} instances."
95
+ )
96
+
97
+ if not all(isinstance(prop, expected_type) for prop in props):
98
+ raise TypeError(
99
+ f"All elements of the {property_name} property must be instances of "
100
+ f"{expected_type.__name__}."
101
+ )
102
+
103
+
72
104
  class PyMCStateSpace:
73
105
  r"""
74
106
  Base class for Linear Gaussian Statespace models in PyMC.
@@ -236,8 +268,8 @@ class PyMCStateSpace:
236
268
  self._fit_exog_data: dict[str, dict] = {}
237
269
 
238
270
  self._needs_exog_data = None
239
- self._name_to_variable = {}
240
- self._name_to_data = {}
271
+ self._tensor_variable_info = SymbolicVariableInfo()
272
+ self._tensor_data_info = SymbolicDataInfo()
241
273
 
242
274
  self.k_endog = k_endog
243
275
  self.k_states = k_states
@@ -245,6 +277,8 @@ class PyMCStateSpace:
245
277
  self.measurement_error = measurement_error
246
278
  self.mode = mode
247
279
 
280
+ self._populate_properties()
281
+
248
282
  # All models contain a state space representation and a Kalman filter
249
283
  self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
250
284
 
@@ -271,17 +305,107 @@ class PyMCStateSpace:
271
305
  console = Console()
272
306
  console.print(self.requirement_table)
273
307
 
308
+ def _populate_properties(self) -> None:
309
+ self._set_parameters()
310
+ self._set_states()
311
+ self._set_shocks()
312
+ self._set_coords()
313
+ self._set_data_info()
314
+
315
+ def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
316
+ """
317
+ Provides parameter metadata to the model.
318
+
319
+ Optional. Default implementation sets an empty ParameterInfo.
320
+ Child classes can override to define model parameters.
321
+ """
322
+ return
323
+
324
+ def _set_parameters(self) -> None:
325
+ param_list = self.set_parameters()
326
+ self._param_info = ParameterInfo(parameters=param_list)
327
+
328
+ def set_states(self) -> State | tuple[State, ...] | None:
329
+ """
330
+ Provide state metadata to the model.
331
+
332
+ Optional. Default implementation creates generic state names (state_0, state_1, ...).
333
+ Child classes can override to provide meaningful state names.
334
+ """
335
+ hidden_states = [
336
+ State(name=f"hidden_state_{name}", observed=False) for name in range(self.k_states or 0)
337
+ ]
338
+ observed_states = [
339
+ State(name=f"observed_state_{name}", observed=True) for name in range(self.k_endog or 0)
340
+ ]
341
+ return *hidden_states, *observed_states
342
+
343
+ def _set_states(self) -> None:
344
+ states = self.set_states()
345
+ _validate_property(states, "states", State)
346
+
347
+ if isinstance(states, State):
348
+ states = (states,)
349
+
350
+ self._state_info = StateInfo(states=states)
351
+
352
+ def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
353
+ """
354
+ Provide shock metadata to the model.
355
+
356
+ Optional. Default implementation creates generic shock names (shock_0, shock_1, ...).
357
+ Child classes can override to provide meaningful shock names.
358
+ """
359
+ return tuple(Shock(name=f"shock_{i}") for i in range(self.k_posdef))
360
+
361
+ def _set_shocks(self) -> None:
362
+ shocks = self.set_shocks()
363
+ _validate_property(shocks, "shocks", Shock)
364
+ if isinstance(shocks, Shock):
365
+ shocks = (shocks,)
366
+ self._shock_info = ShockInfo(shocks=shocks)
367
+
368
+ def default_coords(self) -> tuple[Coord, ...]:
369
+ return CoordInfo.default_coords_from_model(self).items
370
+
371
+ def set_coords(self) -> Coord | tuple[Coord, ...] | None:
372
+ """
373
+ Provide coordinates to the model.
374
+
375
+ Optional. Default implementation sets an empty CoordInfo.
376
+ Child classes can override to provide model-specific coordinates.
377
+ """
378
+ return self.default_coords()
379
+
380
+ def _set_coords(self) -> None:
381
+ coords = self.set_coords()
382
+ _validate_property(coords, "coords", Coord)
383
+ if isinstance(coords, Coord):
384
+ coords = (coords,)
385
+ self._coords_info = CoordInfo(coords=coords)
386
+
387
+ def set_data_info(self) -> Data | tuple[Data, ...] | None:
388
+ """
389
+ Provide data_info metadata to the model.
390
+
391
+ Optional. Default implementation sets an empty DataInfo.
392
+ Child classes can override if the model requires exogenous data.
393
+ """
394
+ return
395
+
396
+ def _set_data_info(self) -> None:
397
+ data_info = self.set_data_info()
398
+ _validate_property(data_info, "data_info", Data)
399
+ if isinstance(data_info, Data):
400
+ data_info = (data_info,)
401
+ self._data_info = DataInfo(data=data_info)
402
+
274
403
  def _populate_prior_requirements(self) -> None:
275
404
  """
276
405
  Add requirements about priors needed for the model to a rich table, including their names,
277
406
  shapes, named dimensions, and any parameter constraints.
278
407
  """
279
- # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
280
- # is not true.
281
- try:
282
- if not isinstance(self.param_info, dict):
283
- return
284
- except NotImplementedError:
408
+ if not self.param_info:
285
409
  return
286
410
 
287
411
  if self.requirement_table is None:
@@ -296,10 +420,7 @@ class PyMCStateSpace:
296
420
  """
297
421
  Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
298
422
  """
299
- try:
300
- if not isinstance(self.data_info, dict):
301
- return
302
- except NotImplementedError:
423
+ if not self.data_info:
303
424
  return
304
425
 
305
426
  if self.requirement_table is None:
@@ -374,24 +495,24 @@ class PyMCStateSpace:
374
495
  return self.subbed_ssm
375
496
 
376
497
  @property
377
- def param_names(self) -> list[str]:
498
+ def param_names(self) -> tuple[str, ...]:
378
499
  """
379
500
  Names of model parameters
380
501
 
381
502
  A list of all parameters expected by the model. Each parameter will be sought inside the active PyMC model
382
503
  context when ``build_statespace_graph`` is invoked.
383
504
  """
384
- raise NotImplementedError("The param_names property has not been implemented!")
505
+ return self._param_info.names
385
506
 
386
507
  @property
387
- def data_names(self) -> list[str]:
508
+ def data_names(self) -> tuple[str, ...]:
388
509
  """
389
510
  Names of data variables expected by the model.
390
511
 
391
512
  This does not include the observed data series, which is automatically handled by PyMC. This property only
392
513
  needs to be implemented for models that expect exogenous data.
393
514
  """
394
- return []
515
+ return self._data_info.exogenous_names
395
516
 
396
517
  @property
397
518
  def param_info(self) -> dict[str, dict[str, Any]]:
@@ -406,7 +527,7 @@ class PyMCStateSpace:
406
527
  positive semi-definite, etc)
407
528
  * key: "dims", value: tuple of strings
408
529
  """
409
- raise NotImplementedError("The params_info property has not been implemented!")
530
+ return self._param_info.to_dict()
410
531
 
411
532
  @property
412
533
  def data_info(self) -> dict[str, dict[str, Any]]:
@@ -419,31 +540,30 @@ class PyMCStateSpace:
419
540
  * key: "shape", value: a tuple of integers
420
541
  * key: "dims", value: tuple of strings
421
542
  """
422
- raise NotImplementedError("The data_info property has not been implemented!")
543
+ return self._data_info.to_dict()
423
544
 
424
545
  @property
425
- def state_names(self) -> list[str]:
546
+ def state_names(self) -> tuple[str, ...]:
426
547
  """
427
548
  A k_states length list of strings, associated with the model's hidden states
428
549
 
429
550
  """
430
-
431
- raise NotImplementedError("The state_names property has not been implemented!")
551
+ return self._state_info.unobserved_state_names
432
552
 
433
553
  @property
434
- def observed_states(self) -> list[str]:
554
+ def observed_states(self) -> tuple[str, ...]:
435
555
  """
436
556
  A k_endog length list of strings, associated with the model's observed states
437
557
  """
438
- raise NotImplementedError("The observed_states property has not been implemented!")
558
+ return self._state_info.observed_state_names
439
559
 
440
560
  @property
441
- def shock_names(self) -> list[str]:
561
+ def shock_names(self) -> tuple[str, ...]:
442
562
  """
443
563
  A k_posdef length list of strings, associated with the model's shock processes
444
564
 
445
565
  """
446
- raise NotImplementedError("The shock_names property has not been implemented!")
566
+ return self._shock_info.names
447
567
 
448
568
  @property
449
569
  def default_priors(self) -> dict[str, Callable]:
@@ -464,10 +584,10 @@ class PyMCStateSpace:
464
584
  should come from the default names defined in ``statespace.utils.constants`` for them to be detected by
465
585
  sampling methods.
466
586
  """
467
- raise NotImplementedError("The coords property has not been implemented!")
587
+ return self._coords_info.to_dict()
468
588
 
469
589
  @property
470
- def param_dims(self) -> dict[str, Sequence[str]]:
590
+ def param_dims(self) -> dict[str, tuple[str, ...]]:
471
591
  """
472
592
  Dictionary of named dimensions for each model parameter
473
593
 
@@ -475,8 +595,17 @@ class PyMCStateSpace:
475
595
  PyMC random variable. Dimensions should come from the default names defined in ``statespace.utils.constants``
476
596
  for them to be detected by sampling methods.
477
597
 
598
+ Note: Scalar parameters (with dims=None) are not included in this dictionary.
478
599
  """
479
- raise NotImplementedError("The param_dims property has not been implemented!")
600
+ return {param.name: param.dims for param in self._param_info if param.dims is not None}
601
+
602
+ @property
603
+ def _name_to_variable(self):
604
+ return self._tensor_variable_info.to_dict()
605
+
606
+ @property
607
+ def _name_to_data(self):
608
+ return self._tensor_data_info.to_dict()
480
609
 
481
610
  def add_default_priors(self) -> None:
482
611
  """
@@ -520,14 +649,15 @@ class PyMCStateSpace:
520
649
  f"parameters."
521
650
  )
522
651
 
523
- if name in self._name_to_variable.keys():
652
+ if name in self._tensor_variable_info:
524
653
  raise ValueError(
525
654
  f"{name} is already a registered placeholder variable with shape "
526
- f"{self._name_to_variable[name].type.shape}"
655
+ f"{self._tensor_variable_info[name].type.shape}"
527
656
  )
528
657
 
529
658
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
530
- self._name_to_variable[name] = placeholder
659
+ tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
660
+ self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
531
661
  return placeholder
532
662
 
533
663
  def make_and_register_data(
@@ -559,14 +689,15 @@ class PyMCStateSpace:
559
689
  f"parameters."
560
690
  )
561
691
 
562
- if name in self._name_to_data.keys():
692
+ if name in self._tensor_data_info:
563
693
  raise ValueError(
564
694
  f"{name} is already a registered placeholder variable with shape "
565
- f"{self._name_to_data[name].type.shape}"
695
+ f"{self._tensor_data_info[name].type.shape}"
566
696
  )
567
697
 
568
698
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
569
- self._name_to_data[name] = placeholder
699
+ tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
700
+ self._tensor_data_info = self._tensor_data_info.add(tensor_data)
570
701
  return placeholder
571
702
 
572
703
  def make_symbolic_graph(self) -> None:
@@ -741,10 +872,8 @@ class PyMCStateSpace:
741
872
 
742
873
  Only used when models require exogenous data. The observed data is not added to the model using this method!
743
874
  """
744
-
745
- try:
746
- data_names = self.data_names
747
- except NotImplementedError:
875
+ data_names = self.data_names
876
+ if not data_names:
748
877
  return
749
878
 
750
879
  pymc_model = modelcontext(None)
@@ -892,7 +1021,6 @@ class PyMCStateSpace:
892
1021
  .. deprecated:: 0.2.5
893
1022
  The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
894
1023
  model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
895
-
896
1024
  """
897
1025
  if mode is not None:
898
1026
  warnings.warn(
@@ -1430,7 +1558,11 @@ class PyMCStateSpace:
1430
1558
  """
1431
1559
 
1432
1560
  return self._sample_conditional(
1433
- idata=idata, group="prior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1561
+ idata=idata,
1562
+ group="prior",
1563
+ random_seed=random_seed,
1564
+ mvn_method=mvn_method,
1565
+ **kwargs,
1434
1566
  )
1435
1567
 
1436
1568
  def sample_conditional_posterior(
@@ -1473,7 +1605,11 @@ class PyMCStateSpace:
1473
1605
  """
1474
1606
 
1475
1607
  return self._sample_conditional(
1476
- idata=idata, group="posterior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1608
+ idata=idata,
1609
+ group="posterior",
1610
+ random_seed=random_seed,
1611
+ mvn_method=mvn_method,
1612
+ **kwargs,
1477
1613
  )
1478
1614
 
1479
1615
  def sample_unconditional_prior(
@@ -2080,11 +2216,11 @@ class PyMCStateSpace:
2080
2216
  forecast_index: pd.RangeIndex | pd.DatetimeIndex,
2081
2217
  name=None,
2082
2218
  ):
2083
- try:
2084
- var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
2085
- except NotImplementedError:
2219
+ if not self.data_info:
2086
2220
  return scenario
2087
2221
 
2222
+ var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
2223
+
2088
2224
  if any(len(dims) > 1 for dims in var_to_dims.values()):
2089
2225
  raise NotImplementedError(">2d exogenous data is not yet supported.")
2090
2226
  coords = {
@@ -1,4 +1,3 @@
1
- import numpy as np
2
1
  import pymc as pm
3
2
  import pytensor
4
3
  import pytensor.tensor as pt
@@ -8,7 +7,9 @@ from pymc.distributions.dist_math import check_parameters
8
7
  from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9
8
  from pymc.distributions.shape_utils import get_support_shape_1d
10
9
  from pymc.logprob.abstract import _logprob
10
+ from pymc.pytensorf import normalize_rng_param
11
11
  from pytensor.graph.basic import Node
12
+ from pytensor.tensor.random import multivariate_normal
12
13
 
13
14
  floatX = pytensor.config.floatX
14
15
  COV_ZERO_TOL = 0
@@ -152,6 +153,7 @@ class _LinearGaussianStateSpace(Continuous):
152
153
  Q,
153
154
  steps,
154
155
  size=None,
156
+ rng=None,
155
157
  sequence_names=None,
156
158
  append_x0=True,
157
159
  method="svd",
@@ -178,7 +180,7 @@ class _LinearGaussianStateSpace(Continuous):
178
180
  ]
179
181
  non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
180
182
 
181
- rng = pytensor.shared(np.random.default_rng())
183
+ rng = normalize_rng_param(rng)
182
184
 
183
185
  def sort_args(args):
184
186
  sorted_args = []
@@ -367,44 +369,25 @@ class SequenceMvNormal(Continuous):
367
369
 
368
370
  @classmethod
369
371
  def dist(cls, mus, covs, logp, method="svd", **kwargs):
372
+ mus, covs, logp = map(pt.as_tensor_variable, (mus, covs, logp))
370
373
  return super().dist([mus, covs, logp], method=method, **kwargs)
371
374
 
372
375
  @classmethod
373
- def rv_op(cls, mus, covs, logp, method="svd", size=None):
374
- # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
375
- if mus.ndim > 2:
376
- mus = pt.moveaxis(mus, -2, 0)
377
- if covs.ndim > 3:
378
- covs = pt.moveaxis(covs, -3, 0)
379
-
380
- mus_, covs_ = mus.type(), covs.type()
381
-
376
+ def rv_op(cls, mus, covs, logp, method="svd", size=None, rng=None):
377
+ rng = normalize_rng_param(rng)
382
378
  logp_ = logp.type()
383
- rng = pytensor.shared(np.random.default_rng())
384
-
385
- def step(mu, cov, rng):
386
- new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
387
- return new_rng, mvn
388
-
389
- seq_mvn_rng, mvn_seq = pytensor.scan(
390
- step,
391
- sequences=[mus_, covs_],
392
- outputs_info=[rng, None],
393
- strict=True,
394
- n_steps=mus_.shape[0],
395
- return_updates=False,
396
- )
397
- mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
398
379
 
399
- # Move time axis back to position -2 so batches are on the left
400
- if mvn_seq.ndim > 2:
401
- mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
380
+ mus_, covs_ = mus.type(), covs.type()
381
+ seq_mvn_rng, mvn_seq = multivariate_normal(
382
+ mean=mus_, cov=covs_, rng=rng, method=method
383
+ ).owner.outputs
402
384
 
403
385
  mvn_seq_op = KalmanFilterRV(
404
386
  inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
405
387
  )
406
388
 
407
389
  mvn_seq = mvn_seq_op(mus, covs, logp, rng)
390
+
408
391
  return mvn_seq
409
392
 
410
393
 
@@ -8,7 +8,7 @@ from pymc.pytensorf import constant_fold
8
8
  from pytensor.graph.basic import Variable
9
9
  from pytensor.raise_op import Assert
10
10
  from pytensor.tensor import TensorVariable
11
- from pytensor.tensor.slinalg import solve_triangular
11
+ from pytensor.tensor.linalg import solve_triangular
12
12
 
13
13
  from pymc_extras.statespace.filters.utilities import (
14
14
  quad_form_sym,