pymc-extras 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/prior.py ADDED
@@ -0,0 +1,1388 @@
1
+ """Class that represents a prior distribution.
2
+
3
+ The `Prior` class is a wrapper around PyMC distributions that allows the user
4
+ to create outside of the PyMC model.
5
+
6
+ Examples
7
+ --------
8
+ Create a normal prior.
9
+
10
+ .. code-block:: python
11
+
12
+ from pymc_extras.prior import Prior
13
+
14
+ normal = Prior("Normal")
15
+
16
+ Create a hierarchical normal prior by using distributions for the parameters
17
+ and specifying the dims.
18
+
19
+ .. code-block:: python
20
+
21
+ hierarchical_normal = Prior(
22
+ "Normal",
23
+ mu=Prior("Normal"),
24
+ sigma=Prior("HalfNormal"),
25
+ dims="channel",
26
+ )
27
+
28
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
29
+
30
+ .. code-block:: python
31
+
32
+ non_centered_hierarchical_normal = Prior(
33
+ "Normal",
34
+ mu=Prior("Normal"),
35
+ sigma=Prior("HalfNormal"),
36
+ dims="channel",
37
+ # Only change needed to make it non-centered
38
+ centered=False,
39
+ )
40
+
41
+ Create a hierarchical beta prior by using Beta distribution, distributions for
42
+ the parameters, and specifying the dims.
43
+
44
+ .. code-block:: python
45
+
46
+ hierarchical_beta = Prior(
47
+ "Beta",
48
+ alpha=Prior("HalfNormal"),
49
+ beta=Prior("HalfNormal"),
50
+ dims="channel",
51
+ )
52
+
53
+ Create a transformed hierarchical normal prior by using the `transform`
54
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
55
+
56
+ .. code-block:: python
57
+
58
+ transformed_hierarchical_normal = Prior(
59
+ "Normal",
60
+ mu=Prior("Normal"),
61
+ sigma=Prior("HalfNormal"),
62
+ transform="sigmoid",
63
+ dims="channel",
64
+ )
65
+
66
+ Create a prior with a custom transform function by registering it with
67
+ `register_tensor_transform`.
68
+
69
+ .. code-block:: python
70
+
71
+ from pymc_extras.prior import register_tensor_transform
72
+
73
+ def custom_transform(x):
74
+ return x ** 2
75
+
76
+ register_tensor_transform("square", custom_transform)
77
+
78
+ custom_distribution = Prior("Normal", transform="square")
79
+
80
+ """
81
+
82
+ from __future__ import annotations
83
+
84
+ import copy
85
+
86
+ from collections.abc import Callable
87
+ from functools import partial
88
+ from inspect import signature
89
+ from typing import Any, Protocol, runtime_checkable
90
+
91
+ import numpy as np
92
+ import pymc as pm
93
+ import pytensor.tensor as pt
94
+ import xarray as xr
95
+
96
+ from pydantic import InstanceOf, validate_call
97
+ from pydantic.dataclasses import dataclass
98
+ from pymc.distributions.shape_utils import Dims
99
+
100
+ from pymc_extras.deserialize import deserialize, register_deserialization
101
+
102
+
103
+ class UnsupportedShapeError(Exception):
104
+ """Error for when the shapes from variables are not compatible."""
105
+
106
+
107
+ class UnsupportedDistributionError(Exception):
108
+ """Error for when an unsupported distribution is used."""
109
+
110
+
111
+ class UnsupportedParameterizationError(Exception):
112
+ """The follow parameterization is not supported."""
113
+
114
+
115
+ class MuAlreadyExistsError(Exception):
116
+ """Error for when 'mu' is present in Prior."""
117
+
118
+ def __init__(self, distribution: Prior) -> None:
119
+ self.distribution = distribution
120
+ self.message = f"The mu parameter is already defined in {distribution}"
121
+ super().__init__(self.message)
122
+
123
+
124
+ class UnknownTransformError(Exception):
125
+ """Error for when an unknown transform is used."""
126
+
127
+
128
+ def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
129
+ """Remove leading 'x' from the args."""
130
+ while args and args[0] == "x":
131
+ args.pop(0)
132
+
133
+ return args
134
+
135
+
136
+ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
137
+ """Take a tensor of dims `dims` and align it to `desired_dims`.
138
+
139
+ Doesn't check for validity of the dims
140
+
141
+ Examples
142
+ --------
143
+ 1D to 2D with new dim
144
+
145
+ .. code-block:: python
146
+
147
+ x = np.array([1, 2, 3])
148
+ dims = "channel"
149
+
150
+ desired_dims = ("channel", "group")
151
+
152
+ handle_dims(x, dims, desired_dims)
153
+
154
+ """
155
+ x = pt.as_tensor_variable(x)
156
+
157
+ if np.ndim(x) == 0:
158
+ return x
159
+
160
+ dims = dims if isinstance(dims, tuple) else (dims,)
161
+ desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)
162
+
163
+ if difference := set(dims).difference(desired_dims):
164
+ raise UnsupportedShapeError(
165
+ f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
166
+ f"{difference} is missing from the desired dims."
167
+ )
168
+
169
+ aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)
170
+
171
+ missing_dims = aligned_dims.sum(axis=0) == 0
172
+ new_idx = aligned_dims.argmax(axis=0)
173
+
174
+ args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)]
175
+ args = _remove_leading_xs(args)
176
+ return x.dimshuffle(*args)
177
+
178
+
179
+ DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
180
+
181
+
182
+ def create_dim_handler(desired_dims: Dims) -> DimHandler:
183
+ """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
184
+
185
+ def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
186
+ return handle_dims(x, dims, desired_dims)
187
+
188
+ return func
189
+
190
+
191
+ def _dims_to_str(obj: tuple[str, ...]) -> str:
192
+ if len(obj) == 1:
193
+ return f'"{obj[0]}"'
194
+
195
+ return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"
196
+
197
+
198
+ def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
199
+ if not hasattr(pm, name):
200
+ raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}")
201
+
202
+ return getattr(pm, name)
203
+
204
+
205
+ Transform = Callable[[pt.TensorLike], pt.TensorLike]
206
+
207
+ CUSTOM_TRANSFORMS: dict[str, Transform] = {}
208
+
209
+
210
+ def register_tensor_transform(name: str, transform: Transform) -> None:
211
+ """Register a tensor transform function to be used in the `Prior` class.
212
+
213
+ Parameters
214
+ ----------
215
+ name : str
216
+ The name of the transform.
217
+ func : Callable[[pt.TensorLike], pt.TensorLike]
218
+ The function to apply to the tensor.
219
+
220
+ Examples
221
+ --------
222
+ Register a custom transform function.
223
+
224
+ .. code-block:: python
225
+
226
+ from pymc_extras.prior import (
227
+ Prior,
228
+ register_tensor_transform,
229
+ )
230
+
231
+ def custom_transform(x):
232
+ return x ** 2
233
+
234
+ register_tensor_transform("square", custom_transform)
235
+
236
+ custom_distribution = Prior("Normal", transform="square")
237
+
238
+ """
239
+ CUSTOM_TRANSFORMS[name] = transform
240
+
241
+
242
+ def _get_transform(name: str):
243
+ if name in CUSTOM_TRANSFORMS:
244
+ return CUSTOM_TRANSFORMS[name]
245
+
246
+ for module in (pt, pm.math):
247
+ if hasattr(module, name):
248
+ break
249
+ else:
250
+ module = None
251
+
252
+ if not module:
253
+ msg = (
254
+ f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
255
+ "If this is a custom function, register it with the "
256
+ "`pymc_extras.prior.register_tensor_transform` function before "
257
+ "previous function call."
258
+ )
259
+
260
+ raise UnknownTransformError(msg)
261
+
262
+ return getattr(module, name)
263
+
264
+
265
+ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
266
+ return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}
267
+
268
+
269
+ @runtime_checkable
270
+ class VariableFactory(Protocol):
271
+ """Protocol for something that works like a Prior class."""
272
+
273
+ dims: tuple[str, ...]
274
+
275
+ def create_variable(self, name: str) -> pt.TensorVariable:
276
+ """Create a TensorVariable."""
277
+
278
+
279
+ def sample_prior(
280
+ factory: VariableFactory,
281
+ coords=None,
282
+ name: str = "variable",
283
+ wrap: bool = False,
284
+ **sample_prior_predictive_kwargs,
285
+ ) -> xr.Dataset:
286
+ """Sample the prior for an arbitrary VariableFactory.
287
+
288
+ Parameters
289
+ ----------
290
+ factory : VariableFactory
291
+ The factory to sample from.
292
+ coords : dict[str, list[str]], optional
293
+ The coordinates for the variable, by default None.
294
+ Only required if the dims are specified.
295
+ name : str, optional
296
+ The name of the variable, by default "variable".
297
+ wrap : bool, optional
298
+ Whether to wrap the variable in a `pm.Deterministic` node, by default False.
299
+ sample_prior_predictive_kwargs : dict
300
+ Additional arguments to pass to `pm.sample_prior_predictive`.
301
+
302
+ Returns
303
+ -------
304
+ xr.Dataset
305
+ The dataset of the prior samples.
306
+
307
+ Example
308
+ -------
309
+ Sample from an arbitrary variable factory.
310
+
311
+ .. code-block:: python
312
+
313
+ import pymc as pm
314
+
315
+ import pytensor.tensor as pt
316
+
317
+ from pymc_extras.prior import sample_prior
318
+
319
+ class CustomVariableDefinition:
320
+ def __init__(self, dims, n: int):
321
+ self.dims = dims
322
+ self.n = n
323
+
324
+ def create_variable(self, name: str) -> "TensorVariable":
325
+ x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
326
+ return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
327
+
328
+ cubic = CustomVariableDefinition(dims=("channel",), n=3)
329
+ coords = {"channel": ["C1", "C2", "C3"]}
330
+ # Doesn't include the return value
331
+ prior = sample_prior(cubic, coords=coords)
332
+
333
+ prior_with = sample_prior(cubic, coords=coords, wrap=True)
334
+
335
+ """
336
+ coords = coords or {}
337
+
338
+ if isinstance(factory.dims, str):
339
+ dims = (factory.dims,)
340
+ else:
341
+ dims = factory.dims
342
+
343
+ if missing_keys := set(dims) - set(coords.keys()):
344
+ raise KeyError(f"Coords are missing the following dims: {missing_keys}")
345
+
346
+ with pm.Model(coords=coords) as model:
347
+ if wrap:
348
+ pm.Deterministic(name, factory.create_variable(name), dims=factory.dims)
349
+ else:
350
+ factory.create_variable(name)
351
+
352
+ return pm.sample_prior_predictive(
353
+ model=model,
354
+ **sample_prior_predictive_kwargs,
355
+ ).prior
356
+
357
+
358
+ class Prior:
359
+ """A class to represent a prior distribution.
360
+
361
+ Make use of the various helper methods to understand the distributions
362
+ better.
363
+
364
+ - `preliz` attribute to get the equivalent distribution in `preliz`
365
+ - `sample_prior` method to sample from the prior
366
+ - `to_graph` get a dummy model graph with the distribution
367
+ - `constrain` to shift the distribution to a different range
368
+
369
+ Parameters
370
+ ----------
371
+ distribution : str
372
+ The name of PyMC distribution.
373
+ dims : Dims, optional
374
+ The dimensions of the variable, by default None
375
+ centered : bool, optional
376
+ Whether the variable is centered or not, by default True.
377
+ Only allowed for Normal distribution.
378
+ transform : str, optional
379
+ The name of the transform to apply to the variable after it is
380
+ created, by default None or no transform. The transformation must
381
+ be registered with `register_tensor_transform` function or
382
+ be available in either `pytensor.tensor` or `pymc.math`.
383
+
384
+ """
385
+
386
+ # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
387
+ non_centered_distributions: dict[str, dict[str, float]] = {
388
+ "Normal": {"mu": 0, "sigma": 1},
389
+ "StudentT": {"mu": 0, "sigma": 1},
390
+ "ZeroSumNormal": {"sigma": 1},
391
+ }
392
+
393
+ pymc_distribution: type[pm.Distribution]
394
+ pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
395
+
396
+ @validate_call
397
+ def __init__(
398
+ self,
399
+ distribution: str,
400
+ *,
401
+ dims: Dims | None = None,
402
+ centered: bool = True,
403
+ transform: str | None = None,
404
+ **parameters,
405
+ ) -> None:
406
+ self.distribution = distribution
407
+ self.parameters = parameters
408
+ self.dims = dims
409
+ self.centered = centered
410
+ self.transform = transform
411
+
412
+ self._checks()
413
+
414
+ @property
415
+ def distribution(self) -> str:
416
+ """The name of the PyMC distribution."""
417
+ return self._distribution
418
+
419
+ @distribution.setter
420
+ def distribution(self, distribution: str) -> None:
421
+ if hasattr(self, "_distribution"):
422
+ raise AttributeError("Can't change the distribution")
423
+
424
+ self._distribution = distribution
425
+ self.pymc_distribution = _get_pymc_distribution(distribution)
426
+
427
+ @property
428
+ def transform(self) -> str | None:
429
+ """The name of the transform to apply to the variable after it is created."""
430
+ return self._transform
431
+
432
+ @transform.setter
433
+ def transform(self, transform: str | None) -> None:
434
+ self._transform = transform
435
+ self.pytensor_transform = not transform or _get_transform(transform) # type: ignore
436
+
437
+ @property
438
+ def dims(self) -> Dims:
439
+ """The dimensions of the variable."""
440
+ return self._dims
441
+
442
+ @dims.setter
443
+ def dims(self, dims) -> None:
444
+ if isinstance(dims, str):
445
+ dims = (dims,)
446
+
447
+ if isinstance(dims, list):
448
+ dims = tuple(dims)
449
+
450
+ self._dims = dims or ()
451
+
452
+ self._param_dims_work()
453
+ self._unique_dims()
454
+
455
+ def __getitem__(self, key: str) -> Prior | Any:
456
+ """Return the parameter of the prior."""
457
+ return self.parameters[key]
458
+
459
+ def _checks(self) -> None:
460
+ if not self.centered:
461
+ self._correct_non_centered_distribution()
462
+
463
+ self._parameters_are_at_least_subset_of_pymc()
464
+ self._convert_lists_to_numpy()
465
+ self._parameters_are_correct_type()
466
+
467
+ def _parameters_are_at_least_subset_of_pymc(self) -> None:
468
+ pymc_params = _get_pymc_parameters(self.pymc_distribution)
469
+ if not set(self.parameters.keys()).issubset(pymc_params):
470
+ msg = (
471
+ f"Parameters {set(self.parameters.keys())} "
472
+ "are not a subset of the pymc distribution "
473
+ f"parameters {set(pymc_params)}"
474
+ )
475
+ raise ValueError(msg)
476
+
477
+ def _convert_lists_to_numpy(self) -> None:
478
+ def convert(x):
479
+ if not isinstance(x, list):
480
+ return x
481
+
482
+ return np.array(x)
483
+
484
+ self.parameters = {key: convert(value) for key, value in self.parameters.items()}
485
+
486
+ def _parameters_are_correct_type(self) -> None:
487
+ supported_types = (
488
+ int,
489
+ float,
490
+ np.ndarray,
491
+ Prior,
492
+ pt.TensorVariable,
493
+ VariableFactory,
494
+ )
495
+
496
+ incorrect_types = {
497
+ param: type(value)
498
+ for param, value in self.parameters.items()
499
+ if not isinstance(value, supported_types)
500
+ }
501
+ if incorrect_types:
502
+ msg = (
503
+ "Parameters must be one of the following types: "
504
+ f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}"
505
+ )
506
+ raise ValueError(msg)
507
+
508
+ def _correct_non_centered_distribution(self) -> None:
509
+ if not self.centered and self.distribution not in self.non_centered_distributions:
510
+ raise UnsupportedParameterizationError(
511
+ f"{self.distribution!r} is not supported for non-centered parameterization. "
512
+ f"Choose from {list(self.non_centered_distributions.keys())}"
513
+ )
514
+
515
+ required_parameters = set(self.non_centered_distributions[self.distribution].keys())
516
+
517
+ if set(self.parameters.keys()) < required_parameters:
518
+ msg = " and ".join([f"{param!r}" for param in required_parameters])
519
+ raise ValueError(
520
+ f"Must have at least {msg} parameter for non-centered for {self.distribution!r}"
521
+ )
522
+
523
+ def _unique_dims(self) -> None:
524
+ if not self.dims:
525
+ return
526
+
527
+ if len(self.dims) != len(set(self.dims)):
528
+ raise ValueError("Dims must be unique")
529
+
530
+ def _param_dims_work(self) -> None:
531
+ other_dims = set()
532
+ for value in self.parameters.values():
533
+ if hasattr(value, "dims"):
534
+ other_dims.update(value.dims)
535
+
536
+ if not other_dims.issubset(self.dims):
537
+ raise UnsupportedShapeError(
538
+ f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
539
+ )
540
+
541
+ def __str__(self) -> str:
542
+ """Return a string representation of the prior."""
543
+ param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()])
544
+ param_str = "" if not param_str else f", {param_str}"
545
+
546
+ dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else ""
547
+ centered_str = f", centered={self.centered}" if not self.centered else ""
548
+ transform_str = f', transform="{self.transform}"' if self.transform else ""
549
+ return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})'
550
+
551
+ def __repr__(self) -> str:
552
+ """Return a string representation of the prior."""
553
+ return f"{self}"
554
+
555
+ def _create_parameter(self, param, value, name):
556
+ if not hasattr(value, "create_variable"):
557
+ return value
558
+
559
+ child_name = f"{name}_{param}"
560
+ return self.dim_handler(value.create_variable(child_name), value.dims)
561
+
562
+ def _create_centered_variable(self, name: str):
563
+ parameters = {
564
+ param: self._create_parameter(param, value, name)
565
+ for param, value in self.parameters.items()
566
+ }
567
+ return self.pymc_distribution(name, **parameters, dims=self.dims)
568
+
569
+ def _create_non_centered_variable(self, name: str) -> pt.TensorVariable:
570
+ def handle_variable(var_name: str):
571
+ parameter = self.parameters[var_name]
572
+ if not hasattr(parameter, "create_variable"):
573
+ return parameter
574
+
575
+ return self.dim_handler(
576
+ parameter.create_variable(f"{name}_{var_name}"),
577
+ parameter.dims,
578
+ )
579
+
580
+ defaults = self.non_centered_distributions[self.distribution]
581
+ other_parameters = {
582
+ param: handle_variable(param)
583
+ for param in self.parameters.keys()
584
+ if param not in defaults
585
+ }
586
+ offset = self.pymc_distribution(
587
+ f"{name}_offset",
588
+ **defaults,
589
+ **other_parameters,
590
+ dims=self.dims,
591
+ )
592
+ if "mu" in self.parameters:
593
+ mu = (
594
+ handle_variable("mu")
595
+ if isinstance(self.parameters["mu"], Prior)
596
+ else self.parameters["mu"]
597
+ )
598
+ else:
599
+ mu = 0
600
+
601
+ sigma = (
602
+ handle_variable("sigma")
603
+ if isinstance(self.parameters["sigma"], Prior)
604
+ else self.parameters["sigma"]
605
+ )
606
+
607
+ return pm.Deterministic(
608
+ name,
609
+ mu + sigma * offset,
610
+ dims=self.dims,
611
+ )
612
+
613
+ def create_variable(self, name: str) -> pt.TensorVariable:
614
+ """Create a PyMC variable from the prior.
615
+
616
+ Must be used in a PyMC model context.
617
+
618
+ Parameters
619
+ ----------
620
+ name : str
621
+ The name of the variable.
622
+
623
+ Returns
624
+ -------
625
+ pt.TensorVariable
626
+ The PyMC variable.
627
+
628
+ Examples
629
+ --------
630
+ Create a hierarchical normal variable in larger PyMC model.
631
+
632
+ .. code-block:: python
633
+
634
+ dist = Prior(
635
+ "Normal",
636
+ mu=Prior("Normal"),
637
+ sigma=Prior("HalfNormal"),
638
+ dims="channel",
639
+ )
640
+
641
+ coords = {"channel": ["C1", "C2", "C3"]}
642
+ with pm.Model(coords=coords):
643
+ var = dist.create_variable("var")
644
+
645
+ """
646
+ self.dim_handler = create_dim_handler(self.dims)
647
+
648
+ if self.transform:
649
+ var_name = f"{name}_raw"
650
+
651
+ def transform(var):
652
+ return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims)
653
+ else:
654
+ var_name = name
655
+
656
+ def transform(var):
657
+ return var
658
+
659
+ create_variable = (
660
+ self._create_centered_variable if self.centered else self._create_non_centered_variable
661
+ )
662
+ var = create_variable(name=var_name)
663
+ return transform(var)
664
+
665
+ @property
666
+ def preliz(self):
667
+ """Create an equivalent preliz distribution.
668
+
669
+ Helpful to visualize a distribution when it is univariate.
670
+
671
+ Returns
672
+ -------
673
+ preliz.distributions.Distribution
674
+
675
+ Examples
676
+ --------
677
+ Create a preliz distribution from a prior.
678
+
679
+ .. code-block:: python
680
+
681
+ from pymc_extras.prior import Prior
682
+
683
+ dist = Prior("Gamma", alpha=5, beta=1)
684
+ dist.preliz.plot_pdf()
685
+
686
+ """
687
+ import preliz as pz
688
+
689
+ return getattr(pz, self.distribution)(**self.parameters)
690
+
691
+ def to_dict(self) -> dict[str, Any]:
692
+ """Convert the prior to dictionary format.
693
+
694
+ Returns
695
+ -------
696
+ dict[str, Any]
697
+ The dictionary format of the prior.
698
+
699
+ Examples
700
+ --------
701
+ Convert a prior to the dictionary format.
702
+
703
+ .. code-block:: python
704
+
705
+ from pymc_extras.prior import Prior
706
+
707
+ dist = Prior("Normal", mu=0, sigma=1)
708
+
709
+ dist.to_dict()
710
+
711
+ Convert a hierarchical prior to the dictionary format.
712
+
713
+ .. code-block:: python
714
+
715
+ dist = Prior(
716
+ "Normal",
717
+ mu=Prior("Normal"),
718
+ sigma=Prior("HalfNormal"),
719
+ dims="channel",
720
+ )
721
+
722
+ dist.to_dict()
723
+
724
+ """
725
+ data: dict[str, Any] = {
726
+ "dist": self.distribution,
727
+ }
728
+ if self.parameters:
729
+
730
+ def handle_value(value):
731
+ if isinstance(value, Prior):
732
+ return value.to_dict()
733
+
734
+ if isinstance(value, pt.TensorVariable):
735
+ value = value.eval()
736
+
737
+ if isinstance(value, np.ndarray):
738
+ return value.tolist()
739
+
740
+ if hasattr(value, "to_dict"):
741
+ return value.to_dict()
742
+
743
+ return value
744
+
745
+ data["kwargs"] = {
746
+ param: handle_value(value) for param, value in self.parameters.items()
747
+ }
748
+ if not self.centered:
749
+ data["centered"] = False
750
+
751
+ if self.dims:
752
+ data["dims"] = self.dims
753
+
754
+ if self.transform:
755
+ data["transform"] = self.transform
756
+
757
+ return data
758
+
759
+ @classmethod
760
+ def from_dict(cls, data) -> Prior:
761
+ """Create a Prior from the dictionary format.
762
+
763
+ Parameters
764
+ ----------
765
+ data : dict[str, Any]
766
+ The dictionary format of the prior.
767
+
768
+ Returns
769
+ -------
770
+ Prior
771
+ The prior distribution.
772
+
773
+ Examples
774
+ --------
775
+ Convert prior in the dictionary format to a Prior instance.
776
+
777
+ .. code-block:: python
778
+
779
+ from pymc_extras.prior import Prior
780
+
781
+ data = {
782
+ "dist": "Normal",
783
+ "kwargs": {"mu": 0, "sigma": 1},
784
+ }
785
+
786
+ dist = Prior.from_dict(data)
787
+ dist
788
+ # Prior("Normal", mu=0, sigma=1)
789
+
790
+ """
791
+ if not isinstance(data, dict):
792
+ msg = (
793
+ "Must be a dictionary representation of a prior distribution. "
794
+ f"Not of type: {type(data)}"
795
+ )
796
+ raise ValueError(msg)
797
+
798
+ dist = data["dist"]
799
+ kwargs = data.get("kwargs", {})
800
+
801
+ def handle_value(value):
802
+ if isinstance(value, dict):
803
+ return deserialize(value)
804
+
805
+ if isinstance(value, list):
806
+ return np.array(value)
807
+
808
+ return value
809
+
810
+ kwargs = {param: handle_value(value) for param, value in kwargs.items()}
811
+ centered = data.get("centered", True)
812
+ dims = data.get("dims")
813
+ if isinstance(dims, list):
814
+ dims = tuple(dims)
815
+ transform = data.get("transform")
816
+
817
+ return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs)
818
+
819
+ def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior:
820
+ """Create a new prior with a given mass constrained within the given bounds.
821
+
822
+ Wrapper around `preliz.maxent`.
823
+
824
+ Parameters
825
+ ----------
826
+ lower : float
827
+ The lower bound.
828
+ upper : float
829
+ The upper bound.
830
+ mass: float = 0.95
831
+ The mass of the distribution to keep within the bounds.
832
+ kwargs : dict
833
+ Additional arguments to pass to `pz.maxent`.
834
+
835
+ Returns
836
+ -------
837
+ Prior
838
+ The maximum entropy prior with a mass constrained to the given bounds.
839
+
840
+ Examples
841
+ --------
842
+ Create a Beta distribution that is constrained to have 95% of the mass
843
+ between 0.5 and 0.8.
844
+
845
+ .. code-block:: python
846
+
847
+ dist = Prior(
848
+ "Beta",
849
+ ).constrain(lower=0.5, upper=0.8)
850
+
851
+ Create a Beta distribution with mean 0.6, that is constrained to
852
+ have 95% of the mass between 0.5 and 0.8.
853
+
854
+ .. code-block:: python
855
+
856
+ dist = Prior(
857
+ "Beta",
858
+ mu=0.6,
859
+ ).constrain(lower=0.5, upper=0.8)
860
+
861
+ """
862
+ from preliz import maxent
863
+
864
+ if self.transform:
865
+ raise ValueError("Can't constrain a transformed variable")
866
+
867
+ if kwargs is None:
868
+ kwargs = {}
869
+ kwargs.setdefault("plot", False)
870
+
871
+ if kwargs["plot"]:
872
+ new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict
873
+ else:
874
+ new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict
875
+
876
+ return Prior(
877
+ self.distribution,
878
+ dims=self.dims,
879
+ transform=self.transform,
880
+ centered=self.centered,
881
+ **new_parameters,
882
+ )
883
+
884
+ def __eq__(self, other) -> bool:
885
+ """Check if two priors are equal."""
886
+ if not isinstance(other, Prior):
887
+ return False
888
+
889
+ try:
890
+ np.testing.assert_equal(self.parameters, other.parameters)
891
+ except AssertionError:
892
+ return False
893
+
894
+ return (
895
+ self.distribution == other.distribution
896
+ and self.dims == other.dims
897
+ and self.centered == other.centered
898
+ and self.transform == other.transform
899
+ )
900
+
901
+ def sample_prior(
902
+ self,
903
+ coords=None,
904
+ name: str = "variable",
905
+ **sample_prior_predictive_kwargs,
906
+ ) -> xr.Dataset:
907
+ """Sample the prior distribution for the variable.
908
+
909
+ Parameters
910
+ ----------
911
+ coords : dict[str, list[str]], optional
912
+ The coordinates for the variable, by default None.
913
+ Only required if the dims are specified.
914
+ name : str, optional
915
+ The name of the variable, by default "variable".
916
+ sample_prior_predictive_kwargs : dict
917
+ Additional arguments to pass to `pm.sample_prior_predictive`.
918
+
919
+ Returns
920
+ -------
921
+ xr.Dataset
922
+ The dataset of the prior samples.
923
+
924
+ Example
925
+ -------
926
+ Sample from a hierarchical normal distribution.
927
+
928
+ .. code-block:: python
929
+
930
+ dist = Prior(
931
+ "Normal",
932
+ mu=Prior("Normal"),
933
+ sigma=Prior("HalfNormal"),
934
+ dims="channel",
935
+ )
936
+
937
+ coords = {"channel": ["C1", "C2", "C3"]}
938
+ prior = dist.sample_prior(coords=coords)
939
+
940
+ """
941
+ return sample_prior(
942
+ factory=self,
943
+ coords=coords,
944
+ name=name,
945
+ **sample_prior_predictive_kwargs,
946
+ )
947
+
948
+ def __deepcopy__(self, memo) -> Prior:
949
+ """Return a deep copy of the prior."""
950
+ if id(self) in memo:
951
+ return memo[id(self)]
952
+
953
+ copy_obj = Prior(
954
+ self.distribution,
955
+ dims=copy.copy(self.dims),
956
+ centered=self.centered,
957
+ transform=self.transform,
958
+ **copy.deepcopy(self.parameters),
959
+ )
960
+ memo[id(self)] = copy_obj
961
+ return copy_obj
962
+
963
+ def deepcopy(self) -> Prior:
964
+ """Return a deep copy of the prior."""
965
+ return copy.deepcopy(self)
966
+
967
+ def to_graph(self):
968
+ """Generate a graph of the variables.
969
+
970
+ Examples
971
+ --------
972
+ Create the graph for a 2D transformed hierarchical distribution.
973
+
974
+ .. code-block:: python
975
+
976
+ from pymc_extras.prior import Prior
977
+
978
+ mu = Prior(
979
+ "Normal",
980
+ mu=Prior("Normal"),
981
+ sigma=Prior("HalfNormal"),
982
+ dims="channel",
983
+ )
984
+ sigma = Prior("HalfNormal", dims="channel")
985
+ dist = Prior(
986
+ "Normal",
987
+ mu=mu,
988
+ sigma=sigma,
989
+ dims=("channel", "geo"),
990
+ centered=False,
991
+ transform="sigmoid",
992
+ )
993
+
994
+ dist.to_graph()
995
+
996
+ .. image:: /_static/example-graph.png
997
+ :alt: Example graph
998
+
999
+ """
1000
+ coords = {name: ["DUMMY"] for name in self.dims}
1001
+ with pm.Model(coords=coords) as model:
1002
+ self.create_variable("var")
1003
+
1004
+ return pm.model_to_graphviz(model)
1005
+
1006
+ def create_likelihood_variable(
1007
+ self,
1008
+ name: str,
1009
+ mu: pt.TensorLike,
1010
+ observed: pt.TensorLike,
1011
+ ) -> pt.TensorVariable:
1012
+ """Create a likelihood variable from the prior.
1013
+
1014
+ Will require that the distribution has a `mu` parameter
1015
+ and that it has not been set in the parameters.
1016
+
1017
+ Parameters
1018
+ ----------
1019
+ name : str
1020
+ The name of the variable.
1021
+ mu : pt.TensorLike
1022
+ The mu parameter for the likelihood.
1023
+ observed : pt.TensorLike
1024
+ The observed data.
1025
+
1026
+ Returns
1027
+ -------
1028
+ pt.TensorVariable
1029
+ The PyMC variable.
1030
+
1031
+ Examples
1032
+ --------
1033
+ Create a likelihood variable in a larger PyMC model.
1034
+
1035
+ .. code-block:: python
1036
+
1037
+ import pymc as pm
1038
+
1039
+ dist = Prior("Normal", sigma=Prior("HalfNormal"))
1040
+
1041
+ with pm.Model():
1042
+ # Create the likelihood variable
1043
+ mu = pm.Normal("mu", mu=0, sigma=1)
1044
+ dist.create_likelihood_variable("y", mu=mu, observed=observed)
1045
+
1046
+ """
1047
+ if "mu" not in _get_pymc_parameters(self.pymc_distribution):
1048
+ raise UnsupportedDistributionError(
1049
+ f"Likelihood distribution {self.distribution!r} is not supported."
1050
+ )
1051
+
1052
+ if "mu" in self.parameters:
1053
+ raise MuAlreadyExistsError(self)
1054
+
1055
+ distribution = self.deepcopy()
1056
+ distribution.parameters["mu"] = mu
1057
+ distribution.parameters["observed"] = observed
1058
+ return distribution.create_variable(name)
1059
+
1060
+
1061
+ class VariableNotFound(Exception):
1062
+ """Variable is not found."""
1063
+
1064
+
1065
+ def _remove_random_variable(var: pt.TensorVariable) -> None:
1066
+ if var.name is None:
1067
+ raise ValueError("This isn't removable")
1068
+
1069
+ name: str = var.name
1070
+
1071
+ model = pm.modelcontext(None)
1072
+ for idx, free_rv in enumerate(model.free_RVs):
1073
+ if var == free_rv:
1074
+ index_to_remove = idx
1075
+ break
1076
+ else:
1077
+ raise VariableNotFound(f"Variable {var.name!r} not found")
1078
+
1079
+ var.name = None
1080
+ model.free_RVs.pop(index_to_remove)
1081
+ model.named_vars.pop(name)
1082
+
1083
+
1084
@dataclass
class Censored:
    """Create censored random variable.

    Wraps a :class:`Prior` and, at variable-creation time, replaces the raw
    random variable with a ``pm.Censored`` version bounded below by ``lower``
    and above by ``upper``.

    Examples
    --------
    Create a censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior("Normal")
        censored_normal = Censored(normal, lower=0)

    Create hierarchical censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )
        censored_normal = Censored(normal, lower=0)

        coords = {"channel": range(3)}
        samples = censored_normal.sample_prior(coords=coords)

    """

    # Distribution to censor; must be centered and untransformed
    # (validated in __post_init__).
    distribution: InstanceOf[Prior]
    # Censoring bounds. The +/- inf defaults leave that side unbounded.
    lower: float | InstanceOf[pt.TensorVariable] = -np.inf
    upper: float | InstanceOf[pt.TensorVariable] = np.inf

    def __post_init__(self) -> None:
        """Check validity at initialization.

        Raises
        ------
        ValueError
            If the wrapped distribution is non-centered or carries a
            transform; censoring relies on the ``.dist()`` API, which
            supports neither.
        """
        if not self.distribution.centered:
            raise ValueError(
                "Censored distribution must be centered so that .dist() API can be used on distribution."
            )

        if self.distribution.transform is not None:
            raise ValueError(
                "Censored distribution can't have a transform so that .dist() API can be used on distribution."
            )

    @property
    def dims(self) -> tuple[str, ...]:
        """The dims from the distribution to censor."""
        return self.distribution.dims

    @dims.setter
    def dims(self, dims) -> None:
        # Forward dim changes to the wrapped distribution, which owns them.
        self.distribution.dims = dims

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create censored random variable.

        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        pt.TensorVariable
            The censored PyMC variable registered under ``name``.
        """
        # Create the raw variable, then detach it from the model so the
        # same name can be reused for the censored wrapper below.
        dist = self.distribution.create_variable(name)
        _remove_random_variable(var=dist)

        return pm.Censored(
            name,
            dist,
            lower=self.lower,
            upper=self.upper,
            dims=self.dims,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert the censored distribution to a dictionary."""

        def handle_value(value):
            # Tensor bounds are evaluated to plain lists so the resulting
            # dictionary is serializable.
            if isinstance(value, pt.TensorVariable):
                return value.eval().tolist()

            return value

        return {
            "class": "Censored",
            "data": {
                "dist": self.distribution.to_dict(),
                "lower": handle_value(self.lower),
                "upper": handle_value(self.upper),
            },
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Censored:
        """Create a censored distribution from a dictionary.

        Parameters
        ----------
        data : dict[str, Any]
            Dictionary in the format produced by :meth:`to_dict`.

        Returns
        -------
        Censored
            The reconstructed censored distribution.
        """
        data = data["data"]
        return cls(  # type: ignore
            distribution=deserialize(data["dist"]),
            lower=data["lower"],
            upper=data["upper"],
        )

    def sample_prior(
        self,
        coords=None,
        name: str = "variable",
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample the prior distribution for the variable.

        Parameters
        ----------
        coords : dict[str, list[str]], optional
            The coordinates for the variable, by default None.
            Only required if the dims are specified.
        name : str, optional
            The name of the variable, by default "variable".
        sample_prior_predictive_kwargs : dict
            Additional arguments to pass to `pm.sample_prior_predictive`.

        Returns
        -------
        xr.Dataset
            The dataset of the prior samples.

        Example
        -------
        Sample from a censored Gamma distribution.

        .. code-block:: python

            gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
            dist = Censored(gamma, lower=0.5)

            coords = {"channel": ["C1", "C2", "C3"]}
            prior = dist.sample_prior(coords=coords)

        """
        return sample_prior(
            factory=self,
            coords=coords,
            name=name,
            **sample_prior_predictive_kwargs,
        )

    def to_graph(self):
        """Generate a graph of the variables.

        Builds a throwaway model with single-element ("DUMMY") coordinates
        for each dim so the structure can be rendered without real data.

        Examples
        --------
        Create graph for a censored Normal distribution

        .. code-block:: python

            from pymc_extras.prior import Prior, Censored

            normal = Prior("Normal")
            censored_normal = Censored(normal, lower=0)

            censored_normal.to_graph()

        """
        coords = {name: ["DUMMY"] for name in self.dims}
        with pm.Model(coords=coords) as model:
            self.create_variable("var")

        return pm.model_to_graphviz(model)

    def create_likelihood_variable(
        self,
        name: str,
        mu: pt.TensorLike,
        observed: pt.TensorLike,
    ) -> pt.TensorVariable:
        """Create observed censored variable.

        Will require that the distribution has a `mu` parameter
        and that it has not been set in the parameters.

        Parameters
        ----------
        name : str
            The name of the variable.
        mu : pt.TensorLike
            The mu parameter for the likelihood.
        observed : pt.TensorLike
            The observed data.

        Returns
        -------
        pt.TensorVariable
            The PyMC variable.

        Raises
        ------
        UnsupportedDistributionError
            If the wrapped distribution does not take a ``mu`` parameter.
        MuAlreadyExistsError
            If ``mu`` is already set on the wrapped distribution.

        Examples
        --------
        Create a censored likelihood variable in a larger PyMC model.

        .. code-block:: python

            import pymc as pm
            from pymc_extras.prior import Prior, Censored

            normal = Prior("Normal", sigma=Prior("HalfNormal"))
            dist = Censored(normal, lower=0)

            observed = 1

            with pm.Model():
                # Create the likelihood variable
                mu = pm.HalfNormal("mu", sigma=1)
                dist.create_likelihood_variable("y", mu=mu, observed=observed)

        """
        if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
            raise UnsupportedDistributionError(
                f"Likelihood distribution {self.distribution.distribution!r} is not supported."
            )

        if "mu" in self.distribution.parameters:
            raise MuAlreadyExistsError(self.distribution)

        # Work on a copy so the original distribution is left untouched.
        distribution = self.distribution.deepcopy()
        distribution.parameters["mu"] = mu

        # Create the raw variable, detach it, then re-wrap it with both
        # censoring bounds and the observed data.
        dist = distribution.create_variable(name)
        _remove_random_variable(var=dist)

        return pm.Censored(
            name,
            dist,
            observed=observed,
            lower=self.lower,
            upper=self.upper,
            dims=self.dims,
        )
1317
+
1318
+
1319
class Scaled:
    """Multiply draws from a wrapped :class:`Prior` by a constant factor.

    Sampling happens on the unscaled distribution; the rescaled value is
    exposed as a deterministic variable, which helps numerical stability.
    """

    def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
        self.dist = dist
        self.factor = factor

    @property
    def dims(self) -> Dims:
        """Dims of the underlying (unscaled) distribution."""
        return self.dist.dims

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a scaled variable.

        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        pt.TensorVariable
            Deterministic variable equal to the unscaled draw times the factor.
        """
        unscaled = self.dist.create_variable(f"{name}_unscaled")
        scaled = unscaled * self.factor
        return pm.Deterministic(name, scaled, dims=self.dims)
1346
+
1347
+
1348
+ def _is_prior_type(data: dict) -> bool:
1349
+ return "dist" in data
1350
+
1351
+
1352
+ def _is_censored_type(data: dict) -> bool:
1353
+ return data.keys() == {"class", "data"} and data["class"] == "Censored"
1354
+
1355
+
1356
+ register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
1357
+ register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)
1358
+
1359
+
1360
def __getattr__(name: str):
    """Get Prior class through the module.

    Module-level attribute hook (PEP 562): any PyMC distribution name
    accessed on this module resolves to a ``Prior`` factory for that
    distribution.

    Examples
    --------
    Create a normal distribution.

    .. code-block:: python

        from pymc_extras.prior import Normal

        dist = Normal(mu=1, sigma=2)

    Create a hierarchical normal distribution.

    .. code-block:: python

        import pymc_extras.prior as pr

        dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel")
        samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})

    """
    if name == "__wrapped__":
        # Protect against doctest machinery probing for __wrapped__.
        return None

    # Raises if the name is not a known PyMC distribution.
    _get_pymc_distribution(name)
    return partial(Prior, distribution=name)