pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/__init__.py +8 -1
  7. pymc_extras/inference/dadvi/__init__.py +0 -0
  8. pymc_extras/inference/dadvi/dadvi.py +351 -0
  9. pymc_extras/inference/fit.py +5 -0
  10. pymc_extras/inference/laplace_approx/find_map.py +32 -47
  11. pymc_extras/inference/laplace_approx/idata.py +27 -6
  12. pymc_extras/inference/laplace_approx/laplace.py +24 -6
  13. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  14. pymc_extras/inference/pathfinder/idata.py +517 -0
  15. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  16. pymc_extras/model/marginal/graph_analysis.py +2 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  21. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  22. pymc_extras/statespace/filters/utilities.py +2 -5
  23. pymc_extras/statespace/models/DFM.py +834 -0
  24. pymc_extras/statespace/models/ETS.py +190 -198
  25. pymc_extras/statespace/models/SARIMAX.py +9 -21
  26. pymc_extras/statespace/models/VARMAX.py +22 -74
  27. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  28. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  29. pymc_extras/statespace/models/utilities.py +7 -0
  30. pymc_extras/statespace/utils/constants.py +3 -1
  31. pymc_extras/utils/model_equivalence.py +2 -2
  32. pymc_extras/utils/prior.py +10 -14
  33. pymc_extras/utils/spline.py +4 -10
  34. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  35. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
  36. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  37. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -334,7 +334,9 @@ class ModelBuilder:
334
334
  >>> model = MyModel(ModelBuilder)
335
335
  >>> idata = az.InferenceData(your_dataset)
336
336
  >>> model.set_idata_attrs(idata=idata)
337
- >>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
337
+ >>> assert (
338
+ ... "id" in idata.attrs
339
+ ... ) # this and the following lines are part of doctest, not user manual
338
340
  >>> assert "model_type" in idata.attrs
339
341
  >>> assert "version" in idata.attrs
340
342
  >>> assert "sampler_config" in idata.attrs
@@ -382,7 +384,7 @@ class ModelBuilder:
382
384
  >>> super().__init__()
383
385
  >>> model = MyModel()
384
386
  >>> model.fit(data)
385
- >>> model.save('model_results.nc') # This will call the overridden method in MyModel
387
+ >>> model.save("model_results.nc") # This will call the overridden method in MyModel
386
388
  """
387
389
  if self.idata is not None and "posterior" in self.idata:
388
390
  file = Path(str(fname))
@@ -432,7 +434,7 @@ class ModelBuilder:
432
434
  --------
433
435
  >>> class MyModel(ModelBuilder):
434
436
  >>> ...
435
- >>> name = './mymodel.nc'
437
+ >>> name = "./mymodel.nc"
436
438
  >>> imported_model = MyModel.load(name)
437
439
  """
438
440
  filepath = Path(str(fname))
@@ -444,6 +446,7 @@ class ModelBuilder:
444
446
  sampler_config=json.loads(idata.attrs["sampler_config"]),
445
447
  )
446
448
  model.idata = idata
449
+ model.is_fitted_ = True
447
450
  dataset = idata.fit_data.to_dataframe()
448
451
  X = dataset.drop(columns=[model.output_var])
449
452
  y = dataset[model.output_var]
@@ -524,6 +527,8 @@ class ModelBuilder:
524
527
  )
525
528
  self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
526
529
 
530
+ self.is_fitted_ = True
531
+
527
532
  return self.idata # type: ignore
528
533
 
529
534
  def predict(
@@ -554,7 +559,7 @@ class ModelBuilder:
554
559
  >>> model = MyModel()
555
560
  >>> idata = model.fit(data)
556
561
  >>> x_pred = []
557
- >>> prediction_data = pd.DataFrame({'input':x_pred})
562
+ >>> prediction_data = pd.DataFrame({"input": x_pred})
558
563
  >>> pred_mean = model.predict(prediction_data)
559
564
  """
560
565
 
pymc_extras/prior.py CHANGED
@@ -70,8 +70,10 @@ Create a prior with a custom transform function by registering it with
70
70
 
71
71
  from pymc_extras.prior import register_tensor_transform
72
72
 
73
+
73
74
  def custom_transform(x):
74
- return x ** 2
75
+ return x**2
76
+
75
77
 
76
78
  register_tensor_transform("square", custom_transform)
77
79
 
@@ -138,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
138
140
 
139
141
  Doesn't check for validity of the dims
140
142
 
143
+ Parameters
144
+ ----------
145
+ x : pt.TensorLike
146
+ The tensor to align.
147
+ dims : Dims
148
+ The current dimensions of the tensor.
149
+ desired_dims : Dims
150
+ The desired dimensions of the tensor.
151
+
152
+ Returns
153
+ -------
154
+ pt.TensorVariable
155
+ The aligned tensor.
156
+
141
157
  Examples
142
158
  --------
143
- 1D to 2D with new dim
159
+ Handle transpose 1D to 2D with new dimension.
144
160
 
145
161
  .. code-block:: python
146
162
 
@@ -177,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
177
193
 
178
194
 
179
195
  DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
196
+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
180
197
 
181
198
 
182
199
  def create_dim_handler(desired_dims: Dims) -> DimHandler:
183
- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
200
+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
201
+
202
+ Parameters
203
+ ----------
204
+ desired_dims : Dims
205
+ The desired dimensions to align to.
206
+
207
+ Returns
208
+ -------
209
+ DimHandler
210
+ A function that takes a tensor and its current dims and aligns it to
211
+ the desired dims.
212
+
213
+
214
+ Examples
215
+ --------
216
+ Create a dim handler to align to ("channel", "group").
217
+
218
+ .. code-block:: python
219
+
220
+ import numpy as np
221
+
222
+ from pymc_extras.prior import create_dim_handler
223
+
224
+ dim_handler = create_dim_handler(("channel", "group"))
225
+
226
+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
227
+
228
+
229
+ """
184
230
 
185
231
  def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
186
232
  return handle_dims(x, dims, desired_dims)
@@ -228,8 +274,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
228
274
  register_tensor_transform,
229
275
  )
230
276
 
277
+
231
278
  def custom_transform(x):
232
- return x ** 2
279
+ return x**2
280
+
233
281
 
234
282
  register_tensor_transform("square", custom_transform)
235
283
 
@@ -268,9 +316,50 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
268
316
 
269
317
  @runtime_checkable
270
318
  class VariableFactory(Protocol):
271
- """Protocol for something that works like a Prior class."""
319
+ '''Protocol for something that works like a Prior class.
320
+
321
+ Sample with :func:`sample_prior`.
322
+
323
+ Examples
324
+ --------
325
+ Create a custom variable factory.
326
+
327
+ .. code-block:: python
328
+
329
+ import pymc as pm
330
+
331
+ import pytensor.tensor as pt
332
+
333
+ from pymc_extras.prior import sample_prior, VariableFactory
334
+
335
+
336
+ class PowerSumDistribution:
337
+ """Create a distribution that is the sum of powers of a base distribution."""
338
+
339
+ def __init__(self, distribution: VariableFactory, n: int):
340
+ self.distribution = distribution
341
+ self.n = n
342
+
343
+ @property
344
+ def dims(self):
345
+ return self.distribution.dims
346
+
347
+ def create_variable(self, name: str) -> "TensorVariable":
348
+ raw = self.distribution.create_variable(f"{name}_raw")
349
+ return pm.Deterministic(
350
+ name,
351
+ pt.sum([raw**n for n in range(1, self.n + 1)], axis=0),
352
+ dims=self.dims,
353
+ )
354
+
355
+
356
+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
357
+ samples = sample_prior(cubic)
358
+
359
+ '''
272
360
 
273
361
  dims: tuple[str, ...]
362
+ """The dimensions of the variable to create."""
274
363
 
275
364
  def create_variable(self, name: str) -> pt.TensorVariable:
276
365
  """Create a TensorVariable."""
@@ -316,6 +405,7 @@ def sample_prior(
316
405
 
317
406
  from pymc_extras.prior import sample_prior
318
407
 
408
+
319
409
  class CustomVariableDefinition:
320
410
  def __init__(self, dims, n: int):
321
411
  self.dims = dims
@@ -323,7 +413,8 @@ def sample_prior(
323
413
 
324
414
  def create_variable(self, name: str) -> "TensorVariable":
325
415
  x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
326
- return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
416
+ return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
417
+
327
418
 
328
419
  cubic = CustomVariableDefinition(dims=("channel",), n=3)
329
420
  coords = {"channel": ["C1", "C2", "C3"]}
@@ -381,6 +472,82 @@ class Prior:
381
472
  be registered with `register_tensor_transform` function or
382
473
  be available in either `pytensor.tensor` or `pymc.math`.
383
474
 
475
+ Examples
476
+ --------
477
+ Create a normal prior.
478
+
479
+ .. code-block:: python
480
+
481
+ from pymc_extras.prior import Prior
482
+
483
+ normal = Prior("Normal")
484
+
485
+ Create a hierarchical normal prior by using distributions for the parameters
486
+ and specifying the dims.
487
+
488
+ .. code-block:: python
489
+
490
+ hierarchical_normal = Prior(
491
+ "Normal",
492
+ mu=Prior("Normal"),
493
+ sigma=Prior("HalfNormal"),
494
+ dims="channel",
495
+ )
496
+
497
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
498
+
499
+ .. code-block:: python
500
+
501
+ non_centered_hierarchical_normal = Prior(
502
+ "Normal",
503
+ mu=Prior("Normal"),
504
+ sigma=Prior("HalfNormal"),
505
+ dims="channel",
506
+ # Only change needed to make it non-centered
507
+ centered=False,
508
+ )
509
+
510
+ Create a hierarchical beta prior by using Beta distribution, distributions for
511
+ the parameters, and specifying the dims.
512
+
513
+ .. code-block:: python
514
+
515
+ hierarchical_beta = Prior(
516
+ "Beta",
517
+ alpha=Prior("HalfNormal"),
518
+ beta=Prior("HalfNormal"),
519
+ dims="channel",
520
+ )
521
+
522
+ Create a transformed hierarchical normal prior by using the `transform`
523
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
524
+
525
+ .. code-block:: python
526
+
527
+ transformed_hierarchical_normal = Prior(
528
+ "Normal",
529
+ mu=Prior("Normal"),
530
+ sigma=Prior("HalfNormal"),
531
+ transform="sigmoid",
532
+ dims="channel",
533
+ )
534
+
535
+ Create a prior with a custom transform function by registering it with
536
+ :func:`register_tensor_transform`.
537
+
538
+ .. code-block:: python
539
+
540
+ from pymc_extras.prior import register_tensor_transform
541
+
542
+
543
+ def custom_transform(x):
544
+ return x**2
545
+
546
+
547
+ register_tensor_transform("square", custom_transform)
548
+
549
+ custom_distribution = Prior("Normal", transform="square")
550
+
384
551
  """
385
552
 
386
553
  # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +556,13 @@ class Prior:
389
556
  "StudentT": {"mu": 0, "sigma": 1},
390
557
  "ZeroSumNormal": {"sigma": 1},
391
558
  }
559
+ """Available non-centered distributions and their default parameters."""
392
560
 
393
561
  pymc_distribution: type[pm.Distribution]
562
+ """The PyMC distribution class."""
563
+
394
564
  pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
565
+ """The PyTensor transform function."""
395
566
 
396
567
  @validate_call
397
568
  def __init__(
@@ -1317,9 +1488,33 @@ class Censored:
1317
1488
 
1318
1489
 
1319
1490
  class Scaled:
1320
- """Scaled distribution for numerical stability."""
1491
+ """Scaled distribution for numerical stability.
1492
+
1493
+ This is the same as multiplying the variable by a constant factor.
1494
+
1495
+ Parameters
1496
+ ----------
1497
+ dist : Prior
1498
+ The prior distribution to scale.
1499
+ factor : pt.TensorLike
1500
+ The scaling factor. This will have to be broadcastable to the
1501
+ dimensions of the distribution.
1502
+
1503
+ Examples
1504
+ --------
1505
+ Create a scaled normal distribution.
1506
+
1507
+ .. code-block:: python
1508
+
1509
+ from pymc_extras.prior import Prior, Scaled
1510
+
1511
+ normal = Prior("Normal", mu=0, sigma=1)
1512
+ # Same as Normal(mu=0, sigma=10)
1513
+ scaled_normal = Scaled(normal, factor=10)
1514
+
1515
+ """
1321
1516
 
1322
- def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
1517
+ def __init__(self, dist: Prior, factor: pt.TensorLike) -> None:
1323
1518
  self.dist = dist
1324
1519
  self.factor = factor
1325
1520
 
@@ -28,7 +28,7 @@ def compile_statespace(
28
28
  x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
29
29
  )
30
30
 
31
- inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
31
+ inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))
32
32
 
33
33
  _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
34
34
 
@@ -200,7 +200,7 @@ class BaseFilter(ABC):
200
200
  self.n_endog = Z_shape[-2]
201
201
 
202
202
  data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
203
-
203
+ data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
204
204
  sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
205
205
  params, PARAM_NAMES
206
206
  )
@@ -393,7 +393,7 @@ class BaseFilter(ABC):
393
393
  .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
394
394
  2nd ed, Oxford University Press, 2012.
395
395
  """
396
- a_hat = T.dot(a) + c
396
+ a_hat = T @ a + c
397
397
  P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)
398
398
 
399
399
  return a_hat, P_hat
@@ -580,16 +580,16 @@ class StandardFilter(BaseFilter):
580
580
  .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
581
581
  2nd ed, Oxford University Press, 2012.
582
582
  """
583
- y_hat = d + Z.dot(a)
583
+ y_hat = d + Z @ a
584
584
  v = y - y_hat
585
585
 
586
- PZT = P.dot(Z.T)
586
+ PZT = P.dot(Z.mT)
587
587
  F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
588
588
 
589
- K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
589
+ K = pt.linalg.solve(F.mT, PZT.mT, assume_a="pos", check_finite=False).mT
590
590
  I_KZ = pt.eye(self.n_states) - K.dot(Z)
591
591
 
592
- a_filtered = a + K.dot(v)
592
+ a_filtered = a + K @ v
593
593
  P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
594
594
 
595
595
  F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
@@ -630,9 +630,9 @@ class SquareRootFilter(BaseFilter):
630
630
  a_hat = T.dot(a) + c
631
631
  Q_chol = pt.linalg.cholesky(Q, lower=True)
632
632
 
633
- M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
633
+ M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).mT
634
634
  R_decomp = pt.linalg.qr(M, mode="r")
635
- P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
635
+ P_chol_hat = R_decomp[..., : self.n_states, : self.n_states].mT
636
636
 
637
637
  return a_hat, P_chol_hat
638
638
 
@@ -658,14 +658,14 @@ class SquareRootFilter(BaseFilter):
658
658
  # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
659
659
  # [0, L_pred]]
660
660
  # The Schur decomposition of this matrix will be B (upper triangular). We are
661
- # more insterested in B^T:
661
+ # more interested in B^T:
662
662
  # Structure of B^T = [[chol(F), 0 ],
663
663
  # [K @ chol(F), chol(P_filtered)]
664
664
  zeros = pt.zeros((self.n_states, self.n_endog))
665
665
  upper = pt.horizontal_stack(H_chol, Z @ P_chol)
666
666
  lower = pt.horizontal_stack(zeros, P_chol)
667
667
  A_T = pt.vertical_stack(upper, lower)
668
- B = pt.linalg.qr(A_T.T, mode="r").T
668
+ B = pt.linalg.qr(A_T.mT, mode="r").mT
669
669
 
670
670
  F_chol = B[: self.n_endog, : self.n_endog]
671
671
  K_F_chol = B[self.n_endog :, : self.n_endog]
@@ -677,6 +677,7 @@ class SquareRootFilter(BaseFilter):
677
677
  inner_term = solve_triangular(
678
678
  F_chol, solve_triangular(F_chol, v, lower=True), lower=True
679
679
  )
680
+
680
681
  loss = (v.T @ inner_term).ravel()
681
682
 
682
683
  # abs necessary because we're not guaranteed a positive diagonal from the schur decomposition
@@ -800,7 +801,7 @@ class UnivariateFilter(BaseFilter):
800
801
  obs_cov[-1],
801
802
  )
802
803
 
803
- P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter)
804
+ P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter)
804
805
  a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
805
806
 
806
807
  ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())
@@ -1,8 +1,6 @@
1
1
  import pytensor
2
2
  import pytensor.tensor as pt
3
3
 
4
- from pytensor.tensor.nlinalg import matrix_dot
5
-
6
4
  from pymc_extras.statespace.filters.utilities import (
7
5
  quad_form_sym,
8
6
  split_vars_into_seq_and_nonseq,
@@ -105,7 +103,7 @@ class KalmanSmoother:
105
103
  a_hat, P_hat = self.predict(a, P, T, R, Q)
106
104
 
107
105
  # Use pinv, otherwise P_hat is singular when there is missing data
108
- smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
106
+ smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT
109
107
  a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)
110
108
 
111
109
  P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
@@ -1,7 +1,5 @@
1
1
  import pytensor.tensor as pt
2
2
 
3
- from pytensor.tensor.nlinalg import matrix_dot
4
-
5
3
  from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
6
4
 
7
5
 
@@ -48,12 +46,11 @@ def split_vars_into_seq_and_nonseq(params, param_names):
48
46
 
49
47
 
50
48
  def stabilize(cov, jitter=JITTER_DEFAULT):
51
- # Ensure diagonal is non-zero
52
49
  cov = cov + pt.identity_like(cov) * jitter
53
50
 
54
51
  return cov
55
52
 
56
53
 
57
54
  def quad_form_sym(A, B):
58
- out = matrix_dot(A, B, A.T)
59
- return 0.5 * (out + out.T)
55
+ out = A @ B @ A.mT
56
+ return 0.5 * (out + out.mT)