pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +14 -12
  6. pymc_extras/inference/dadvi/dadvi.py +149 -128
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +196 -151
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  13. pymc_extras/inference/smc/sampling.py +2 -2
  14. pymc_extras/model/marginal/distributions.py +4 -2
  15. pymc_extras/model/marginal/graph_analysis.py +2 -2
  16. pymc_extras/model/marginal/marginal_model.py +12 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/core/statespace.py +2 -1
  21. pymc_extras/statespace/filters/distributions.py +15 -13
  22. pymc_extras/statespace/filters/kalman_filter.py +24 -22
  23. pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  24. pymc_extras/statespace/filters/utilities.py +2 -5
  25. pymc_extras/statespace/models/DFM.py +12 -27
  26. pymc_extras/statespace/models/ETS.py +190 -198
  27. pymc_extras/statespace/models/SARIMAX.py +5 -17
  28. pymc_extras/statespace/models/VARMAX.py +15 -67
  29. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  30. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  31. pymc_extras/statespace/models/utilities.py +7 -0
  32. pymc_extras/utils/model_equivalence.py +2 -2
  33. pymc_extras/utils/prior.py +10 -14
  34. pymc_extras/utils/spline.py +4 -10
  35. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
  36. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
  37. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
  38. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,12 +16,13 @@
16
16
  import collections
17
17
  import logging
18
18
  import time
19
+ import warnings
19
20
 
20
21
  from collections import Counter
21
22
  from collections.abc import Callable, Iterator
22
23
  from dataclasses import asdict, dataclass, field, replace
23
24
  from enum import Enum, auto
24
- from typing import Literal, TypeAlias
25
+ from typing import Literal, Self, TypeAlias
25
26
 
26
27
  import arviz as az
27
28
  import filelock
@@ -59,9 +60,6 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainin
59
60
  from rich.table import Table
60
61
  from rich.text import Text
61
62
 
62
- # TODO: change to typing.Self after Python versions greater than 3.10
63
- from typing_extensions import Self
64
-
65
63
  from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
66
64
  from pymc_extras.inference.pathfinder.importance_sampling import (
67
65
  importance_sampling as _importance_sampling,
@@ -280,12 +278,13 @@ def alpha_recover(
280
278
  z = pt.diff(g, axis=0)
281
279
  alpha_l_init = pt.ones(N)
282
280
 
283
- alpha, _ = pytensor.scan(
281
+ alpha = pytensor.scan(
284
282
  fn=compute_alpha_l,
285
283
  outputs_info=alpha_l_init,
286
284
  sequences=[s, z],
287
285
  n_steps=Lp1 - 1,
288
286
  allow_gc=False,
287
+ return_updates=False,
289
288
  )
290
289
 
291
290
  # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
@@ -336,11 +335,12 @@ def inverse_hessian_factors(
336
335
  return pt.set_subtensor(chi_l[j_last], diff_l)
337
336
 
338
337
  chi_init = pt.zeros((J, N))
339
- chi_mat, _ = pytensor.scan(
338
+ chi_mat = pytensor.scan(
340
339
  fn=chi_update,
341
340
  outputs_info=chi_init,
342
341
  sequences=[diff],
343
342
  allow_gc=False,
343
+ return_updates=False,
344
344
  )
345
345
 
346
346
  chi_mat = pt.matrix_transpose(chi_mat)
@@ -379,14 +379,14 @@ def inverse_hessian_factors(
379
379
  eta = pt.diagonal(E, axis1=-2, axis2=-1)
380
380
 
381
381
  # beta: (L, N, 2J)
382
- alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
382
+ alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
383
383
  beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)
384
384
 
385
385
  # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
386
386
 
387
387
  # E_inv: (L, J, J)
388
388
  E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
389
- eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
389
+ eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)
390
390
 
391
391
  # block_dd: (L, J, J)
392
392
  block_dd = (
@@ -532,7 +532,9 @@ def bfgs_sample_sparse(
532
532
 
533
533
  # qr_input: (L, N, 2J)
534
534
  qr_input = inv_sqrt_alpha_diag @ beta
535
- (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
535
+ Q, R = pytensor.scan(
536
+ fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
537
+ )
536
538
 
537
539
  IdN = pt.eye(R.shape[1])[None, ...]
538
540
  IdN += IdN * REGULARISATION_TERM
@@ -625,10 +627,11 @@ def bfgs_sample(
625
627
 
626
628
  L, N, JJ = beta.shape
627
629
 
628
- (alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag), _ = pytensor.scan(
630
+ alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
629
631
  lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
630
632
  sequences=[alpha],
631
633
  allow_gc=False,
634
+ return_updates=False,
632
635
  )
633
636
 
634
637
  u = pt.random.normal(size=(L, num_samples, N))
@@ -1398,6 +1401,7 @@ def multipath_pathfinder(
1398
1401
  random_seed: RandomSeed,
1399
1402
  pathfinder_kwargs: dict = {},
1400
1403
  compile_kwargs: dict = {},
1404
+ display_summary: bool = True,
1401
1405
  ) -> MultiPathfinderResult:
1402
1406
  """
1403
1407
  Fit the Pathfinder Variational Inference algorithm using multiple paths with PyMC/PyTensor backend.
@@ -1556,8 +1560,9 @@ def multipath_pathfinder(
1556
1560
  compute_time=compute_end - compute_start,
1557
1561
  )
1558
1562
  )
1559
- # TODO: option to disable summary, save to file, etc.
1560
- mpr.display_summary()
1563
+ # Display summary conditionally
1564
+ if display_summary:
1565
+ mpr.display_summary()
1561
1566
  if mpr.all_paths_failed:
1562
1567
  raise ValueError(
1563
1568
  "All paths failed. Consider decreasing the jitter or reparameterizing the model."
@@ -1600,6 +1605,14 @@ def fit_pathfinder(
1600
1605
  pathfinder_kwargs: dict = {},
1601
1606
  compile_kwargs: dict = {},
1602
1607
  initvals: dict | None = None,
1608
+ # New pathfinder result integration options
1609
+ add_pathfinder_groups: bool = True,
1610
+ display_summary: bool | Literal["auto"] = "auto",
1611
+ store_diagnostics: bool = False,
1612
+ pathfinder_group: str = "pathfinder",
1613
+ paths_group: str = "pathfinder_paths",
1614
+ diagnostics_group: str = "pathfinder_diagnostics",
1615
+ config_group: str = "pathfinder_config",
1603
1616
  ) -> az.InferenceData:
1604
1617
  """
1605
1618
  Fit the Pathfinder Variational Inference algorithm.
@@ -1658,6 +1671,22 @@ def fit_pathfinder(
1658
1671
  initvals: dict | None = None
1659
1672
  Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
1660
1673
  If None, the model's default initial values are used.
1674
+ add_pathfinder_groups : bool, optional
1675
+ Whether to add pathfinder results as additional groups to the InferenceData (default is True).
1676
+ When True, adds pathfinder and pathfinder_paths groups with optimization diagnostics.
1677
+ display_summary : bool or "auto", optional
1678
+ Whether to display the pathfinder results summary (default is "auto").
1679
+ "auto" preserves current behavior, False suppresses console output.
1680
+ store_diagnostics : bool, optional
1681
+ Whether to include potentially large diagnostic arrays in the pathfinder groups (default is False).
1682
+ pathfinder_group : str, optional
1683
+ Name for the main pathfinder results group (default is "pathfinder").
1684
+ paths_group : str, optional
1685
+ Name for the per-path results group (default is "pathfinder_paths").
1686
+ diagnostics_group : str, optional
1687
+ Name for the diagnostics group (default is "pathfinder_diagnostics").
1688
+ config_group : str, optional
1689
+ Name for the configuration group (default is "pathfinder_config").
1661
1690
 
1662
1691
  Returns
1663
1692
  -------
@@ -1694,6 +1723,9 @@ def fit_pathfinder(
1694
1723
  maxcor = np.ceil(3 * np.log(N)).astype(np.int32)
1695
1724
  maxcor = max(maxcor, 5)
1696
1725
 
1726
+ # Handle display_summary logic
1727
+ should_display_summary = display_summary == "auto" or display_summary is True
1728
+
1697
1729
  if inference_backend == "pymc":
1698
1730
  mp_result = multipath_pathfinder(
1699
1731
  model,
@@ -1714,6 +1746,7 @@ def fit_pathfinder(
1714
1746
  random_seed=random_seed,
1715
1747
  pathfinder_kwargs=pathfinder_kwargs,
1716
1748
  compile_kwargs=compile_kwargs,
1749
+ display_summary=should_display_summary,
1717
1750
  )
1718
1751
  pathfinder_samples = mp_result.samples
1719
1752
  elif inference_backend == "blackjax":
@@ -1760,4 +1793,30 @@ def fit_pathfinder(
1760
1793
 
1761
1794
  idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
1762
1795
 
1796
+ # Add pathfinder results to InferenceData if requested
1797
+ if add_pathfinder_groups:
1798
+ if inference_backend == "pymc":
1799
+ from pymc_extras.inference.pathfinder.idata import add_pathfinder_to_inference_data
1800
+
1801
+ idata = add_pathfinder_to_inference_data(
1802
+ idata=idata,
1803
+ result=mp_result,
1804
+ model=model,
1805
+ group=pathfinder_group,
1806
+ paths_group=paths_group,
1807
+ diagnostics_group=diagnostics_group,
1808
+ config_group=config_group,
1809
+ store_diagnostics=store_diagnostics,
1810
+ )
1811
+ else:
1812
+ warnings.warn(
1813
+ f"Pathfinder diagnostic groups are only supported with the PyMC backend. "
1814
+ f"Current backend is '{inference_backend}', which does not support adding "
1815
+ "pathfinder diagnostics to InferenceData. The InferenceData will only contain "
1816
+ "posterior samples. To add diagnostic groups, use inference_backend='pymc', "
1817
+ "or set add_pathfinder_groups=False to suppress this warning.",
1818
+ UserWarning,
1819
+ stacklevel=2,
1820
+ )
1821
+
1763
1822
  return idata
@@ -238,7 +238,7 @@ class SMCDiagnostics(NamedTuple):
238
238
  def update_diagnosis(i, history, info, state):
239
239
  le, lli, ancestors, weights_evolution = history
240
240
  return SMCDiagnostics(
241
- le.at[i].set(state.lmbda),
241
+ le.at[i].set(state.tempering_param),
242
242
  lli.at[i].set(info.log_likelihood_increment),
243
243
  ancestors.at[i].set(info.ancestors),
244
244
  weights_evolution.at[i].set(state.weights),
@@ -265,7 +265,7 @@ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_par
265
265
 
266
266
  def cond(carry):
267
267
  i, state, _, _ = carry
268
- return state.lmbda < 1
268
+ return state.tempering_param < 1
269
269
 
270
270
  def one_step(carry):
271
271
  i, state, k, previous_info = carry
@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
282
282
  def logp_fn(marginalized_rv_const, *non_sequences):
283
283
  return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
284
284
 
285
- joint_logps, _ = scan_map(
285
+ joint_logps = scan_map(
286
286
  fn=logp_fn,
287
287
  sequences=marginalized_rv_domain_tensor,
288
288
  non_sequences=[*values, *inputs],
289
289
  mode=Mode().including("local_remove_check_parameter"),
290
+ return_updates=False,
290
291
  )
291
292
 
292
293
  joint_logp = pt.logsumexp(joint_logps, axis=0)
@@ -350,12 +351,13 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
350
351
 
351
352
  P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
352
353
  log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
353
- log_alpha_seq, _ = scan(
354
+ log_alpha_seq = scan(
354
355
  step_alpha,
355
356
  non_sequences=[log_P],
356
357
  outputs_info=[log_alpha_init],
357
358
  # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
358
359
  sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
360
+ return_updates=False,
359
361
  )
360
362
  # Final logp is just the sum of the last scan state
361
363
  joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
@@ -6,8 +6,8 @@ from itertools import zip_longest
6
6
  from pymc import SymbolicRandomVariable
7
7
  from pymc.model.fgraph import ModelVar
8
8
  from pymc.variational.minibatch_rv import MinibatchRandomVariable
9
- from pytensor.graph import Variable, ancestors
10
- from pytensor.graph.basic import io_toposort
9
+ from pytensor.graph.basic import Variable
10
+ from pytensor.graph.traversal import ancestors, io_toposort
11
11
  from pytensor.tensor import TensorType, TensorVariable
12
12
  from pytensor.tensor.blockwise import Blockwise
13
13
  from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -11,7 +11,7 @@ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_po
11
11
  from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
12
12
  from pymc.distributions.transforms import Chain
13
13
  from pymc.logprob.transforms import IntervalTransform
14
- from pymc.model import Model
14
+ from pymc.model import Model, modelcontext
15
15
  from pymc.model.fgraph import (
16
16
  ModelFreeRV,
17
17
  ModelValuedVar,
@@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts):
337
337
 
338
338
 
339
339
  def recover_marginals(
340
- model: Model,
341
340
  idata: InferenceData,
341
+ *,
342
+ model: Model | None = None,
342
343
  var_names: Sequence[str] | None = None,
343
344
  return_samples: bool = True,
344
345
  extend_inferencedata: bool = True,
@@ -389,6 +390,15 @@ def recover_marginals(
389
390
 
390
391
 
391
392
  """
393
+ # Temporary error message for helping with migration
394
+ # Will be removed in a future release
395
+ if isinstance(idata, Model):
396
+ raise TypeError(
397
+ "The order of arguments of `recover_marginals` changed. The first input must be an idata"
398
+ )
399
+
400
+ model = modelcontext(model)
401
+
392
402
  unmarginal_model = unmarginalize(model)
393
403
 
394
404
  # Find the names of the marginalized variables
@@ -334,7 +334,9 @@ class ModelBuilder:
334
334
  >>> model = MyModel(ModelBuilder)
335
335
  >>> idata = az.InferenceData(your_dataset)
336
336
  >>> model.set_idata_attrs(idata=idata)
337
- >>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
337
+ >>> assert (
338
+ ... "id" in idata.attrs
339
+ ... ) # this and the following lines are part of doctest, not user manual
338
340
  >>> assert "model_type" in idata.attrs
339
341
  >>> assert "version" in idata.attrs
340
342
  >>> assert "sampler_config" in idata.attrs
@@ -382,7 +384,7 @@ class ModelBuilder:
382
384
  >>> super().__init__()
383
385
  >>> model = MyModel()
384
386
  >>> model.fit(data)
385
- >>> model.save('model_results.nc') # This will call the overridden method in MyModel
387
+ >>> model.save("model_results.nc") # This will call the overridden method in MyModel
386
388
  """
387
389
  if self.idata is not None and "posterior" in self.idata:
388
390
  file = Path(str(fname))
@@ -432,7 +434,7 @@ class ModelBuilder:
432
434
  --------
433
435
  >>> class MyModel(ModelBuilder):
434
436
  >>> ...
435
- >>> name = './mymodel.nc'
437
+ >>> name = "./mymodel.nc"
436
438
  >>> imported_model = MyModel.load(name)
437
439
  """
438
440
  filepath = Path(str(fname))
@@ -444,6 +446,7 @@ class ModelBuilder:
444
446
  sampler_config=json.loads(idata.attrs["sampler_config"]),
445
447
  )
446
448
  model.idata = idata
449
+ model.is_fitted_ = True
447
450
  dataset = idata.fit_data.to_dataframe()
448
451
  X = dataset.drop(columns=[model.output_var])
449
452
  y = dataset[model.output_var]
@@ -524,6 +527,8 @@ class ModelBuilder:
524
527
  )
525
528
  self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
526
529
 
530
+ self.is_fitted_ = True
531
+
527
532
  return self.idata # type: ignore
528
533
 
529
534
  def predict(
@@ -554,7 +559,7 @@ class ModelBuilder:
554
559
  >>> model = MyModel()
555
560
  >>> idata = model.fit(data)
556
561
  >>> x_pred = []
557
- >>> prediction_data = pd.DataFrame({'input':x_pred})
562
+ >>> prediction_data = pd.DataFrame({"input": x_pred})
558
563
  >>> pred_mean = model.predict(prediction_data)
559
564
  """
560
565
 
pymc_extras/prior.py CHANGED
@@ -70,8 +70,10 @@ Create a prior with a custom transform function by registering it with
70
70
 
71
71
  from pymc_extras.prior import register_tensor_transform
72
72
 
73
+
73
74
  def custom_transform(x):
74
- return x ** 2
75
+ return x**2
76
+
75
77
 
76
78
  register_tensor_transform("square", custom_transform)
77
79
 
@@ -138,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
138
140
 
139
141
  Doesn't check for validity of the dims
140
142
 
143
+ Parameters
144
+ ----------
145
+ x : pt.TensorLike
146
+ The tensor to align.
147
+ dims : Dims
148
+ The current dimensions of the tensor.
149
+ desired_dims : Dims
150
+ The desired dimensions of the tensor.
151
+
152
+ Returns
153
+ -------
154
+ pt.TensorVariable
155
+ The aligned tensor.
156
+
141
157
  Examples
142
158
  --------
143
- 1D to 2D with new dim
159
+ Handle transpose 1D to 2D with new dimension.
144
160
 
145
161
  .. code-block:: python
146
162
 
@@ -177,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
177
193
 
178
194
 
179
195
  DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
196
+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
180
197
 
181
198
 
182
199
  def create_dim_handler(desired_dims: Dims) -> DimHandler:
183
- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
200
+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
201
+
202
+ Parameters
203
+ ----------
204
+ desired_dims : Dims
205
+ The desired dimensions to align to.
206
+
207
+ Returns
208
+ -------
209
+ DimHandler
210
+ A function that takes a tensor and its current dims and aligns it to
211
+ the desired dims.
212
+
213
+
214
+ Examples
215
+ --------
216
+ Create a dim handler to align to ("channel", "group").
217
+
218
+ .. code-block:: python
219
+
220
+ import numpy as np
221
+
222
+ from pymc_extras.prior import create_dim_handler
223
+
224
+ dim_handler = create_dim_handler(("channel", "group"))
225
+
226
+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
227
+
228
+
229
+ """
184
230
 
185
231
  def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
186
232
  return handle_dims(x, dims, desired_dims)
@@ -228,8 +274,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
228
274
  register_tensor_transform,
229
275
  )
230
276
 
277
+
231
278
  def custom_transform(x):
232
- return x ** 2
279
+ return x**2
280
+
233
281
 
234
282
  register_tensor_transform("square", custom_transform)
235
283
 
@@ -268,9 +316,50 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
268
316
 
269
317
  @runtime_checkable
270
318
  class VariableFactory(Protocol):
271
- """Protocol for something that works like a Prior class."""
319
+ '''Protocol for something that works like a Prior class.
320
+
321
+ Sample with :func:`sample_prior`.
322
+
323
+ Examples
324
+ --------
325
+ Create a custom variable factory.
326
+
327
+ .. code-block:: python
328
+
329
+ import pymc as pm
330
+
331
+ import pytensor.tensor as pt
332
+
333
+ from pymc_extras.prior import sample_prior, VariableFactory
334
+
335
+
336
+ class PowerSumDistribution:
337
+ """Create a distribution that is the sum of powers of a base distribution."""
338
+
339
+ def __init__(self, distribution: VariableFactory, n: int):
340
+ self.distribution = distribution
341
+ self.n = n
342
+
343
+ @property
344
+ def dims(self):
345
+ return self.distribution.dims
346
+
347
+ def create_variable(self, name: str) -> "TensorVariable":
348
+ raw = self.distribution.create_variable(f"{name}_raw")
349
+ return pm.Deterministic(
350
+ name,
351
+ pt.sum([raw**n for n in range(1, self.n + 1)], axis=0),
352
+ dims=self.dims,
353
+ )
354
+
355
+
356
+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
357
+ samples = sample_prior(cubic)
358
+
359
+ '''
272
360
 
273
361
  dims: tuple[str, ...]
362
+ """The dimensions of the variable to create."""
274
363
 
275
364
  def create_variable(self, name: str) -> pt.TensorVariable:
276
365
  """Create a TensorVariable."""
@@ -316,6 +405,7 @@ def sample_prior(
316
405
 
317
406
  from pymc_extras.prior import sample_prior
318
407
 
408
+
319
409
  class CustomVariableDefinition:
320
410
  def __init__(self, dims, n: int):
321
411
  self.dims = dims
@@ -323,7 +413,8 @@ def sample_prior(
323
413
 
324
414
  def create_variable(self, name: str) -> "TensorVariable":
325
415
  x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
326
- return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
416
+ return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
417
+
327
418
 
328
419
  cubic = CustomVariableDefinition(dims=("channel",), n=3)
329
420
  coords = {"channel": ["C1", "C2", "C3"]}
@@ -381,6 +472,82 @@ class Prior:
381
472
  be registered with `register_tensor_transform` function or
382
473
  be available in either `pytensor.tensor` or `pymc.math`.
383
474
 
475
+ Examples
476
+ --------
477
+ Create a normal prior.
478
+
479
+ .. code-block:: python
480
+
481
+ from pymc_extras.prior import Prior
482
+
483
+ normal = Prior("Normal")
484
+
485
+ Create a hierarchical normal prior by using distributions for the parameters
486
+ and specifying the dims.
487
+
488
+ .. code-block:: python
489
+
490
+ hierarchical_normal = Prior(
491
+ "Normal",
492
+ mu=Prior("Normal"),
493
+ sigma=Prior("HalfNormal"),
494
+ dims="channel",
495
+ )
496
+
497
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
498
+
499
+ .. code-block:: python
500
+
501
+ non_centered_hierarchical_normal = Prior(
502
+ "Normal",
503
+ mu=Prior("Normal"),
504
+ sigma=Prior("HalfNormal"),
505
+ dims="channel",
506
+ # Only change needed to make it non-centered
507
+ centered=False,
508
+ )
509
+
510
+ Create a hierarchical beta prior by using Beta distribution, distributions for
511
+ the parameters, and specifying the dims.
512
+
513
+ .. code-block:: python
514
+
515
+ hierarchical_beta = Prior(
516
+ "Beta",
517
+ alpha=Prior("HalfNormal"),
518
+ beta=Prior("HalfNormal"),
519
+ dims="channel",
520
+ )
521
+
522
+ Create a transformed hierarchical normal prior by using the `transform`
523
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
524
+
525
+ .. code-block:: python
526
+
527
+ transformed_hierarchical_normal = Prior(
528
+ "Normal",
529
+ mu=Prior("Normal"),
530
+ sigma=Prior("HalfNormal"),
531
+ transform="sigmoid",
532
+ dims="channel",
533
+ )
534
+
535
+ Create a prior with a custom transform function by registering it with
536
+ :func:`register_tensor_transform`.
537
+
538
+ .. code-block:: python
539
+
540
+ from pymc_extras.prior import register_tensor_transform
541
+
542
+
543
+ def custom_transform(x):
544
+ return x**2
545
+
546
+
547
+ register_tensor_transform("square", custom_transform)
548
+
549
+ custom_distribution = Prior("Normal", transform="square")
550
+
384
551
  """
385
552
 
386
553
  # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +556,13 @@ class Prior:
389
556
  "StudentT": {"mu": 0, "sigma": 1},
390
557
  "ZeroSumNormal": {"sigma": 1},
391
558
  }
559
+ """Available non-centered distributions and their default parameters."""
392
560
 
393
561
  pymc_distribution: type[pm.Distribution]
562
+ """The PyMC distribution class."""
563
+
394
564
  pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
565
+ """The PyTensor transform function."""
395
566
 
396
567
  @validate_call
397
568
  def __init__(
@@ -1317,9 +1488,33 @@ class Censored:
1317
1488
 
1318
1489
 
1319
1490
  class Scaled:
1320
- """Scaled distribution for numerical stability."""
1491
+ """Scaled distribution for numerical stability.
1492
+
1493
+ This is the same as multiplying the variable by a constant factor.
1494
+
1495
+ Parameters
1496
+ ----------
1497
+ dist : Prior
1498
+ The prior distribution to scale.
1499
+ factor : pt.TensorLike
1500
+ The scaling factor. This will have to be broadcastable to the
1501
+ dimensions of the distribution.
1502
+
1503
+ Examples
1504
+ --------
1505
+ Create a scaled normal distribution.
1506
+
1507
+ .. code-block:: python
1508
+
1509
+ from pymc_extras.prior import Prior, Scaled
1510
+
1511
+ normal = Prior("Normal", mu=0, sigma=1)
1512
+ # Same as Normal(mu=0, sigma=10)
1513
+ scaled_normal = Scaled(normal, factor=10)
1514
+
1515
+ """
1321
1516
 
1322
- def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
1517
+ def __init__(self, dist: Prior, factor: pt.TensorLike) -> None:
1323
1518
  self.dist = dist
1324
1519
  self.factor = factor
1325
1520
 
@@ -28,7 +28,7 @@ def compile_statespace(
28
28
  x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
29
29
  )
30
30
 
31
- inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
31
+ inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))
32
32
 
33
33
  _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
34
34
 
@@ -2500,13 +2500,14 @@ class PyMCStateSpace:
2500
2500
  next_x = c + T @ x + R @ shock
2501
2501
  return next_x
2502
2502
 
2503
- irf, updates = pytensor.scan(
2503
+ irf = pytensor.scan(
2504
2504
  irf_step,
2505
2505
  sequences=[shock_trajectory],
2506
2506
  outputs_info=[x0],
2507
2507
  non_sequences=[c, T, R],
2508
2508
  n_steps=n_steps,
2509
2509
  strict=True,
2510
+ return_updates=False,
2510
2511
  )
2511
2512
 
2512
2513
  pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])