pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +14 -12
- pymc_extras/inference/dadvi/dadvi.py +149 -128
- pymc_extras/inference/laplace_approx/find_map.py +16 -39
- pymc_extras/inference/laplace_approx/idata.py +22 -4
- pymc_extras/inference/laplace_approx/laplace.py +196 -151
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +71 -12
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/core/statespace.py +2 -1
- pymc_extras/statespace/filters/distributions.py +15 -13
- pymc_extras/statespace/filters/kalman_filter.py +24 -22
- pymc_extras/statespace/filters/kalman_smoother.py +3 -5
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +12 -27
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +5 -17
- pymc_extras/statespace/models/VARMAX.py +15 -67
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/pathfinder/pathfinder.py CHANGED

@@ -16,12 +16,13 @@
 import collections
 import logging
 import time
+import warnings

 from collections import Counter
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from typing import Literal, TypeAlias
+from typing import Literal, Self, TypeAlias

 import arviz as az
 import filelock
@@ -59,9 +60,6 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.table import Table
 from rich.text import Text

-# TODO: change to typing.Self after Python versions greater than 3.10
-from typing_extensions import Self
-
 from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
@@ -280,12 +278,13 @@ def alpha_recover(
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)

-    alpha, _ = pytensor.scan(
+    alpha = pytensor.scan(
         fn=compute_alpha_l,
         outputs_info=alpha_l_init,
         sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
+        return_updates=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
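Note: this is the first of several hunks in this diff (in inverse_hessian_factors, bfgs_sample_sparse, bfgs_sample, the marginal logp scans, and the statespace IRF near the end) that make the same mechanical change: pytensor.scan is called with return_updates=False so it returns only the scan outputs, where callers previously unpacked an (outputs, updates) pair as `outputs, _`. A minimal sketch of the new calling style, assuming a pytensor version that supports the return_updates flag:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")

    # With return_updates=False, scan returns the outputs directly
    # instead of an (outputs, updates) tuple.
    cumsum = pytensor.scan(
        fn=lambda x_t, acc: acc + x_t,  # running sum over the sequence
        sequences=[x],
        outputs_info=[pt.zeros(())],
        return_updates=False,
    )

    fn = pytensor.function([x], cumsum)
    print(fn([1.0, 2.0, 3.0]))  # [1. 3. 6.]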
@@ -336,11 +335,12 @@ def inverse_hessian_factors(
         return pt.set_subtensor(chi_l[j_last], diff_l)

     chi_init = pt.zeros((J, N))
-    chi_mat, _ = pytensor.scan(
+    chi_mat = pytensor.scan(
         fn=chi_update,
         outputs_info=chi_init,
         sequences=[diff],
         allow_gc=False,
+        return_updates=False,
     )

     chi_mat = pt.matrix_transpose(chi_mat)
@@ -379,14 +379,14 @@ def inverse_hessian_factors(
     eta = pt.diagonal(E, axis1=-2, axis2=-1)

     # beta: (L, N, 2J)
-    alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
+    alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
     beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)

     # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html

     # E_inv: (L, J, J)
     E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
-    eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
+    eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)

     # block_dd: (L, J, J)
     block_dd = (
@@ -532,7 +532,9 @@ def bfgs_sample_sparse(

     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
-    (Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)
+    Q, R = pytensor.scan(
+        fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
+    )

     IdN = pt.eye(R.shape[1])[None, ...]
     IdN += IdN * REGULARISATION_TERM
@@ -625,10 +627,11 @@ def bfgs_sample(

     L, N, JJ = beta.shape

-    (alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag), _ = pytensor.scan(
+    alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
         lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
         sequences=[alpha],
         allow_gc=False,
+        return_updates=False,
     )

     u = pt.random.normal(size=(L, num_samples, N))
@@ -1398,6 +1401,7 @@ def multipath_pathfinder(
     random_seed: RandomSeed,
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    display_summary: bool = True,
 ) -> MultiPathfinderResult:
     """
     Fit the Pathfinder Variational Inference algorithm using multiple paths with PyMC/PyTensor backend.
@@ -1556,8 +1560,9 @@
             compute_time=compute_end - compute_start,
         )
     )
-    #
-
+    # Display summary conditionally
+    if display_summary:
+        mpr.display_summary()
     if mpr.all_paths_failed:
         raise ValueError(
             "All paths failed. Consider decreasing the jitter or reparameterizing the model."
@@ -1600,6 +1605,14 @@ def fit_pathfinder(
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
     initvals: dict | None = None,
+    # New pathfinder result integration options
+    add_pathfinder_groups: bool = True,
+    display_summary: bool | Literal["auto"] = "auto",
+    store_diagnostics: bool = False,
+    pathfinder_group: str = "pathfinder",
+    paths_group: str = "pathfinder_paths",
+    diagnostics_group: str = "pathfinder_diagnostics",
+    config_group: str = "pathfinder_config",
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
@@ -1658,6 +1671,22 @@ def fit_pathfinder(
     initvals: dict | None = None
         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
         If None, the model's default initial values are used.
+    add_pathfinder_groups : bool, optional
+        Whether to add pathfinder results as additional groups to the InferenceData (default is True).
+        When True, adds pathfinder and pathfinder_paths groups with optimization diagnostics.
+    display_summary : bool or "auto", optional
+        Whether to display the pathfinder results summary (default is "auto").
+        "auto" preserves current behavior, False suppresses console output.
+    store_diagnostics : bool, optional
+        Whether to include potentially large diagnostic arrays in the pathfinder groups (default is False).
+    pathfinder_group : str, optional
+        Name for the main pathfinder results group (default is "pathfinder").
+    paths_group : str, optional
+        Name for the per-path results group (default is "pathfinder_paths").
+    diagnostics_group : str, optional
+        Name for the diagnostics group (default is "pathfinder_diagnostics").
+    config_group : str, optional
+        Name for the configuration group (default is "pathfinder_config").

     Returns
     -------
@@ -1694,6 +1723,9 @@ def fit_pathfinder(
     maxcor = np.ceil(3 * np.log(N)).astype(np.int32)
     maxcor = max(maxcor, 5)

+    # Handle display_summary logic
+    should_display_summary = display_summary == "auto" or display_summary is True
+
     if inference_backend == "pymc":
         mp_result = multipath_pathfinder(
             model,
@@ -1714,6 +1746,7 @@ def fit_pathfinder(
             random_seed=random_seed,
             pathfinder_kwargs=pathfinder_kwargs,
             compile_kwargs=compile_kwargs,
+            display_summary=should_display_summary,
         )
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
@@ -1760,4 +1793,30 @@

     idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)

+    # Add pathfinder results to InferenceData if requested
+    if add_pathfinder_groups:
+        if inference_backend == "pymc":
+            from pymc_extras.inference.pathfinder.idata import add_pathfinder_to_inference_data
+
+            idata = add_pathfinder_to_inference_data(
+                idata=idata,
+                result=mp_result,
+                model=model,
+                group=pathfinder_group,
+                paths_group=paths_group,
+                diagnostics_group=diagnostics_group,
+                config_group=config_group,
+                store_diagnostics=store_diagnostics,
+            )
+        else:
+            warnings.warn(
+                f"Pathfinder diagnostic groups are only supported with the PyMC backend. "
+                f"Current backend is '{inference_backend}', which does not support adding "
+                "pathfinder diagnostics to InferenceData. The InferenceData will only contain "
+                "posterior samples. To add diagnostic groups, use inference_backend='pymc', "
+                "or set add_pathfinder_groups=False to suppress this warning.",
+                UserWarning,
+                stacklevel=2,
+            )
+
     return idata
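A sketch of how the new fit_pathfinder options compose, based only on the signature and docstring added above; the toy model is illustrative, and it assumes fit_pathfinder resolves the model from the active model context:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu=mu, sigma=1.0, observed=np.array([0.1, -0.3, 0.2]))

        idata = fit_pathfinder(
            add_pathfinder_groups=True,  # attach pathfinder/pathfinder_paths groups
            display_summary=False,  # suppress the console summary
            store_diagnostics=False,  # skip potentially large diagnostic arrays
        )

    # The InferenceData now carries the extra groups alongside the posterior.
    print(idata.groups())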
pymc_extras/inference/smc/sampling.py CHANGED

@@ -238,7 +238,7 @@ class SMCDiagnostics(NamedTuple):
 def update_diagnosis(i, history, info, state):
     le, lli, ancestors, weights_evolution = history
     return SMCDiagnostics(
-        le.at[i].set(state.lmbda),
+        le.at[i].set(state.tempering_param),
         lli.at[i].set(info.log_likelihood_increment),
         ancestors.at[i].set(info.ancestors),
         weights_evolution.at[i].set(state.weights),
@@ -265,7 +265,7 @@ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_particles):

     def cond(carry):
         i, state, _, _ = carry
-        return state.lmbda < 1
+        return state.tempering_param < 1

     def one_step(carry):
         i, state, k, previous_info = carry
pymc_extras/model/marginal/distributions.py CHANGED

@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs):
     def logp_fn(marginalized_rv_const, *non_sequences):
         return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

-    joint_logps, _ = scan_map(
+    joint_logps = scan_map(
         fn=logp_fn,
         sequences=marginalized_rv_domain_tensor,
         non_sequences=[*values, *inputs],
         mode=Mode().including("local_remove_check_parameter"),
+        return_updates=False,
     )

     joint_logp = pt.logsumexp(joint_logps, axis=0)
@@ -350,12 +351,13 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):

     P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
     log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
-    log_alpha_seq, _ = scan(
+    log_alpha_seq = scan(
         step_alpha,
         non_sequences=[log_P],
         outputs_info=[log_alpha_init],
         # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
         sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+        return_updates=False,
     )
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
pymc_extras/model/marginal/graph_analysis.py CHANGED

@@ -6,8 +6,8 @@ from itertools import zip_longest
 from pymc import SymbolicRandomVariable
 from pymc.model.fgraph import ModelVar
 from pymc.variational.minibatch_rv import MinibatchRandomVariable
-from pytensor.graph import Variable
-from pytensor.graph.basic import ancestors, io_toposort
+from pytensor.graph.basic import Variable
+from pytensor.graph.traversal import ancestors, io_toposort
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
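The import moves into pytensor.graph.traversal here (and for explicit_graph_inputs in the statespace compile hunk near the end) track a reorganization of pytensor's graph utilities. A hypothetical compatibility shim for downstream code that must run on both layouts, not something this package itself does:

    # Hypothetical: prefer the new traversal module, fall back to the
    # old location on pytensor versions that predate it.
    try:
        from pytensor.graph.traversal import ancestors, io_toposort
    except ImportError:
        from pytensor.graph.basic import ancestors, io_toposort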
pymc_extras/model/marginal/marginal_model.py CHANGED

@@ -11,7 +11,7 @@ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
-from pymc.model import Model
+from pymc.model import Model, modelcontext
 from pymc.model.fgraph import (
     ModelFreeRV,
     ModelValuedVar,
@@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts):


 def recover_marginals(
-    model: Model,
     idata: InferenceData,
+    *,
+    model: Model | None = None,
     var_names: Sequence[str] | None = None,
     return_samples: bool = True,
     extend_inferencedata: bool = True,
@@ -389,6 +390,15 @@ def recover_marginals(


     """
+    # Temporary error message for helping with migration
+    # Will be removed in a future release
+    if isinstance(idata, Model):
+        raise TypeError(
+            "The order of arguments of `recover_marginals` changed. The first input must be an idata"
+        )
+
+    model = modelcontext(model)
+
     unmarginal_model = unmarginalize(model)

     # Find the names of the marginalized variables
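A sketch of the resulting calling convention, where marginal_model and idata are placeholders for an existing marginalized model and its trace:

    from pymc_extras.model.marginal.marginal_model import recover_marginals

    # Before: recover_marginals(model, idata)
    # After: idata comes first; model is keyword-only and may be omitted
    # inside a model context, where modelcontext() resolves it.
    with marginal_model:
        idata = recover_marginals(idata, return_samples=True)

    # Equivalent explicit form:
    idata = recover_marginals(idata, model=marginal_model, return_samples=True)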
|
pymc_extras/model_builder.py
CHANGED
|
@@ -334,7 +334,9 @@ class ModelBuilder:
|
|
|
334
334
|
>>> model = MyModel(ModelBuilder)
|
|
335
335
|
>>> idata = az.InferenceData(your_dataset)
|
|
336
336
|
>>> model.set_idata_attrs(idata=idata)
|
|
337
|
-
>>> assert
|
|
337
|
+
>>> assert (
|
|
338
|
+
... "id" in idata.attrs
|
|
339
|
+
... ) # this and the following lines are part of doctest, not user manual
|
|
338
340
|
>>> assert "model_type" in idata.attrs
|
|
339
341
|
>>> assert "version" in idata.attrs
|
|
340
342
|
>>> assert "sampler_config" in idata.attrs
|
|
@@ -382,7 +384,7 @@ class ModelBuilder:
         >>> super().__init__()
         >>> model = MyModel()
         >>> model.fit(data)
-        >>> model.save('model_results.nc')  # This will call the overridden method in MyModel
+        >>> model.save("model_results.nc")  # This will call the overridden method in MyModel
         """
         if self.idata is not None and "posterior" in self.idata:
             file = Path(str(fname))
@@ -432,7 +434,7 @@ class ModelBuilder:
         --------
         >>> class MyModel(ModelBuilder):
         >>> ...
-        >>> name = './mymodel.nc'
+        >>> name = "./mymodel.nc"
         >>> imported_model = MyModel.load(name)
         """
         filepath = Path(str(fname))
@@ -444,6 +446,7 @@ class ModelBuilder:
             sampler_config=json.loads(idata.attrs["sampler_config"]),
         )
         model.idata = idata
+        model.is_fitted_ = True
         dataset = idata.fit_data.to_dataframe()
         X = dataset.drop(columns=[model.output_var])
         y = dataset[model.output_var]
@@ -524,6 +527,8 @@ class ModelBuilder:
         )
         self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore

+        self.is_fitted_ = True
+
         return self.idata  # type: ignore

     def predict(
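Taken together with the load() hunk above, a ModelBuilder subclass now reports itself as fitted whether it was just trained or restored from disk. A sketch, where MyModel, X, y, and the file name are hypothetical:

    model = MyModel()
    model.fit(X, y)
    assert model.is_fitted_  # set at the end of fit()

    restored = MyModel.load("model_results.nc")
    assert restored.is_fitted_  # now also set by load(), without refitting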
@@ -554,7 +559,7 @@ class ModelBuilder:
         >>> model = MyModel()
         >>> idata = model.fit(data)
         >>> x_pred = []
-        >>> prediction_data = pd.DataFrame({'input': x_pred})
+        >>> prediction_data = pd.DataFrame({"input": x_pred})
         >>> pred_mean = model.predict(prediction_data)
         """

pymc_extras/prior.py CHANGED

@@ -70,8 +70,10 @@ Create a prior with a custom transform function by registering it with

     from pymc_extras.prior import register_tensor_transform

+
     def custom_transform(x):
-        return x ** 2
+        return x**2
+

     register_tensor_transform("square", custom_transform)

@@ -138,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:

     Doesn't check for validity of the dims

+    Parameters
+    ----------
+    x : pt.TensorLike
+        The tensor to align.
+    dims : Dims
+        The current dimensions of the tensor.
+    desired_dims : Dims
+        The desired dimensions of the tensor.
+
+    Returns
+    -------
+    pt.TensorVariable
+        The aligned tensor.
+
     Examples
     --------
-    1D to 2D with new dimension.
+    Handle transpose 1D to 2D with new dimension.

     .. code-block:: python

@@ -177,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:


 DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
+"""A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""


 def create_dim_handler(desired_dims: Dims) -> DimHandler:
-    """Wrap the handle_dims function."""
+    """Wrap the :func:`handle_dims` function to always use the same desired_dims.
+
+    Parameters
+    ----------
+    desired_dims : Dims
+        The desired dimensions to align to.
+
+    Returns
+    -------
+    DimHandler
+        A function that takes a tensor and its current dims and aligns it to
+        the desired dims.
+
+
+    Examples
+    --------
+    Create a dim handler to align to ("channel", "group").
+
+    .. code-block:: python
+
+        import numpy as np
+
+        from pymc_extras.prior import create_dim_handler
+
+        dim_handler = create_dim_handler(("channel", "group"))
+
+        result = dim_handler(np.array([1, 2, 3]), dims="channel")
+
+
+    """

     def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
         return handle_dims(x, dims, desired_dims)
@@ -228,8 +274,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
         register_tensor_transform,
     )

+
     def custom_transform(x):
-        return x ** 2
+        return x**2
+

     register_tensor_transform("square", custom_transform)

@@ -268,9 +316,50 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:

 @runtime_checkable
 class VariableFactory(Protocol):
-
+    '''Protocol for something that works like a Prior class.
+
+    Sample with :func:`sample_prior`.
+
+    Examples
+    --------
+    Create a custom variable factory.
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        import pytensor.tensor as pt
+
+        from pymc_extras.prior import sample_prior, VariableFactory
+
+
+        class PowerSumDistribution:
+            """Create a distribution that is the sum of powers of a base distribution."""
+
+            def __init__(self, distribution: VariableFactory, n: int):
+                self.distribution = distribution
+                self.n = n
+
+            @property
+            def dims(self):
+                return self.distribution.dims
+
+            def create_variable(self, name: str) -> "TensorVariable":
+                raw = self.distribution.create_variable(f"{name}_raw")
+                return pm.Deterministic(
+                    name,
+                    pt.sum([raw**n for n in range(1, self.n + 1)], axis=0),
+                    dims=self.dims,
+                )
+
+
+        cubic = PowerSumDistribution(Prior("Normal"), n=3)
+        samples = sample_prior(cubic)
+
+    '''

     dims: tuple[str, ...]
+    """The dimensions of the variable to create."""

     def create_variable(self, name: str) -> pt.TensorVariable:
         """Create a TensorVariable."""
@@ -316,6 +405,7 @@ def sample_prior(

     from pymc_extras.prior import sample_prior

+
     class CustomVariableDefinition:
         def __init__(self, dims, n: int):
             self.dims = dims
@@ -323,7 +413,8 @@ def sample_prior(

         def create_variable(self, name: str) -> "TensorVariable":
             x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
-            return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
+            return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
+

     cubic = CustomVariableDefinition(dims=("channel",), n=3)
     coords = {"channel": ["C1", "C2", "C3"]}
@@ -381,6 +472,82 @@ class Prior:
         be registered with `register_tensor_transform` function or
         be available in either `pytensor.tensor` or `pymc.math`.

+    Examples
+    --------
+    Create a normal prior.
+
+    .. code-block:: python
+
+        from pymc_extras.prior import Prior
+
+        normal = Prior("Normal")
+
+    Create a hierarchical normal prior by using distributions for the parameters
+    and specifying the dims.
+
+    .. code-block:: python
+
+        hierarchical_normal = Prior(
+            "Normal",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            dims="channel",
+        )
+
+    Create a non-centered hierarchical normal prior with the `centered` parameter.
+
+    .. code-block:: python
+
+        non_centered_hierarchical_normal = Prior(
+            "Normal",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            dims="channel",
+            # Only change needed to make it non-centered
+            centered=False,
+        )
+
+    Create a hierarchical beta prior by using Beta distribution, distributions for
+    the parameters, and specifying the dims.
+
+    .. code-block:: python
+
+        hierarchical_beta = Prior(
+            "Beta",
+            alpha=Prior("HalfNormal"),
+            beta=Prior("HalfNormal"),
+            dims="channel",
+        )
+
+    Create a transformed hierarchical normal prior by using the `transform`
+    parameter. Here the "sigmoid" transformation comes from `pm.math`.
+
+    .. code-block:: python
+
+        transformed_hierarchical_normal = Prior(
+            "Normal",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            transform="sigmoid",
+            dims="channel",
+        )
+
+    Create a prior with a custom transform function by registering it with
+    :func:`register_tensor_transform`.
+
+    .. code-block:: python
+
+        from pymc_extras.prior import register_tensor_transform
+
+
+        def custom_transform(x):
+            return x**2
+
+
+        register_tensor_transform("square", custom_transform)
+
+        custom_distribution = Prior("Normal", transform="square")
+
     """

     # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +556,13 @@ class Prior:
         "StudentT": {"mu": 0, "sigma": 1},
         "ZeroSumNormal": {"sigma": 1},
     }
+    """Available non-centered distributions and their default parameters."""

     pymc_distribution: type[pm.Distribution]
+    """The PyMC distribution class."""
+
     pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
+    """The PyTensor transform function."""

     @validate_call
     def __init__(
@@ -1317,9 +1488,33 @@ class Censored:


 class Scaled:
-    """Scaled distribution for numerical stability."""
+    """Scaled distribution for numerical stability.
+
+    This is the same as multiplying the variable by a constant factor.
+
+    Parameters
+    ----------
+    dist : Prior
+        The prior distribution to scale.
+    factor : pt.TensorLike
+        The scaling factor. This will have to be broadcastable to the
+        dimensions of the distribution.
+
+    Examples
+    --------
+    Create a scaled normal distribution.
+
+    .. code-block:: python
+
+        from pymc_extras.prior import Prior, Scaled
+
+        normal = Prior("Normal", mu=0, sigma=1)
+        # Same as Normal(mu=0, sigma=10)
+        scaled_normal = Scaled(normal, factor=10)
+
+    """

-    def __init__(self, dist: Prior, factor: float) -> None:
+    def __init__(self, dist: Prior, factor: pt.TensorLike) -> None:
         self.dist = dist
         self.factor = factor

pymc_extras/statespace/core/compile.py CHANGED

@@ -28,7 +28,7 @@ def compile_statespace(
         x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
     )

-    inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
+    inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))

     _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

pymc_extras/statespace/core/statespace.py CHANGED

@@ -2500,13 +2500,14 @@ class PyMCStateSpace:
             next_x = c + T @ x + R @ shock
             return next_x

-        irf, _ = pytensor.scan(
+        irf = pytensor.scan(
             irf_step,
             sequences=[shock_trajectory],
             outputs_info=[x0],
             non_sequences=[c, T, R],
             n_steps=n_steps,
             strict=True,
+            return_updates=False,
         )

         pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])