pymc-extras 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/prior.py ADDED
@@ -0,0 +1,1356 @@
1
+ """Class that represents a prior distribution.
2
+
3
+ The `Prior` class is a wrapper around PyMC distributions that allows the user
4
+ to create priors outside of a PyMC model.
5
+
6
+ Examples
7
+ --------
8
+ Create a normal prior.
9
+
10
+ .. code-block:: python
11
+
12
+ from pymc_extras.prior import Prior
13
+
14
+ normal = Prior("Normal")
15
+
16
+ Create a hierarchical normal prior by using distributions for the parameters
17
+ and specifying the dims.
18
+
19
+ .. code-block:: python
20
+
21
+ hierarchical_normal = Prior(
22
+ "Normal",
23
+ mu=Prior("Normal"),
24
+ sigma=Prior("HalfNormal"),
25
+ dims="channel",
26
+ )
27
+
28
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
29
+
30
+ .. code-block:: python
31
+
32
+ non_centered_hierarchical_normal = Prior(
33
+ "Normal",
34
+ mu=Prior("Normal"),
35
+ sigma=Prior("HalfNormal"),
36
+ dims="channel",
37
+ # Only change needed to make it non-centered
38
+ centered=False,
39
+ )
40
+
41
+ Create a hierarchical beta prior by using Beta distribution, distributions for
42
+ the parameters, and specifying the dims.
43
+
44
+ .. code-block:: python
45
+
46
+ hierarchical_beta = Prior(
47
+ "Beta",
48
+ alpha=Prior("HalfNormal"),
49
+ beta=Prior("HalfNormal"),
50
+ dims="channel",
51
+ )
52
+
53
+ Create a transformed hierarchical normal prior by using the `transform`
54
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
55
+
56
+ .. code-block:: python
57
+
58
+ transformed_hierarchical_normal = Prior(
59
+ "Normal",
60
+ mu=Prior("Normal"),
61
+ sigma=Prior("HalfNormal"),
62
+ transform="sigmoid",
63
+ dims="channel",
64
+ )
65
+
66
+ Create a prior with a custom transform function by registering it with
67
+ `register_tensor_transform`.
68
+
69
+ .. code-block:: python
70
+
71
+ from pymc_extras.prior import register_tensor_transform
72
+
73
+ def custom_transform(x):
74
+ return x ** 2
75
+
76
+ register_tensor_transform("square", custom_transform)
77
+
78
+ custom_distribution = Prior("Normal", transform="square")
79
+
80
+ """
81
+
82
+ from __future__ import annotations
83
+
84
+ import copy
85
+
86
+ from collections.abc import Callable
87
+ from inspect import signature
88
+ from typing import Any, Protocol, runtime_checkable
89
+
90
+ import numpy as np
91
+ import pymc as pm
92
+ import pytensor.tensor as pt
93
+ import xarray as xr
94
+
95
+ from pydantic import InstanceOf, validate_call
96
+ from pydantic.dataclasses import dataclass
97
+ from pymc.distributions.shape_utils import Dims
98
+
99
+ from pymc_extras.deserialize import deserialize, register_deserialization
100
+
101
+
102
class UnsupportedShapeError(Exception):
    """Raised when the shapes coming from different variables are incompatible."""
104
+
105
+
106
class UnsupportedDistributionError(Exception):
    """Raised when a distribution name is not supported (e.g. unknown to PyMC)."""
108
+
109
+
110
class UnsupportedParameterizationError(Exception):
    """The following parameterization is not supported."""
112
+
113
+
114
class MuAlreadyExistsError(Exception):
    """Raised when a 'mu' parameter is already present on a Prior."""

    def __init__(self, distribution: Prior) -> None:
        # Keep both the offending distribution and the rendered message
        # available as attributes for callers that want to inspect them.
        message = f"The mu parameter is already defined in {distribution}"
        self.distribution = distribution
        self.message = message
        super().__init__(message)
121
+
122
+
123
class UnknownTransformError(Exception):
    """Raised when a transform name cannot be resolved to a function."""
125
+
126
+
127
+ def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
128
+ """Remove leading 'x' from the args."""
129
+ while args and args[0] == "x":
130
+ args.pop(0)
131
+
132
+ return args
133
+
134
+
135
def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
    """Align a tensor with dims ``dims`` to the ordering of ``desired_dims``.

    The validity of the dims themselves is not checked.

    Examples
    --------
    1D to 2D with new dim

    .. code-block:: python

        x = np.array([1, 2, 3])
        dims = "channel"

        desired_dims = ("channel", "group")

        handle_dims(x, dims, desired_dims)

    """
    x = pt.as_tensor_variable(x)

    # Scalars broadcast against anything; no alignment needed.
    if np.ndim(x) == 0:
        return x

    if not isinstance(dims, tuple):
        dims = (dims,)
    if not isinstance(desired_dims, tuple):
        desired_dims = (desired_dims,)

    if difference := set(dims).difference(desired_dims):
        raise UnsupportedShapeError(
            f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
            f"{difference} is missing from the desired dims."
        )

    # For each desired dim, pick the matching axis of ``x`` (first match wins,
    # mirroring argmax), or "x" to insert a new broadcastable axis.
    shuffle_args: list[str | int] = [
        dims.index(dim) if dim in dims else "x" for dim in desired_dims
    ]
    return x.dimshuffle(*_remove_leading_xs(shuffle_args))
176
+
177
+
178
+ DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
179
+
180
+
181
def create_dim_handler(desired_dims: Dims) -> DimHandler:
    """Return a closure aligning tensors to ``desired_dims``.

    Thin wrapper keeping the legacy ``create_dim_handler`` call signature
    on top of :func:`handle_dims`.
    """

    def handler(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
        return handle_dims(x, dims, desired_dims)

    return handler
188
+
189
+
190
+ def _dims_to_str(obj: tuple[str, ...]) -> str:
191
+ if len(obj) == 1:
192
+ return f'"{obj[0]}"'
193
+
194
+ return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"
195
+
196
+
197
def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
    """Look up a distribution class by name on the ``pymc`` namespace."""
    if hasattr(pm, name):
        return getattr(pm, name)

    raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}")
202
+
203
+
204
+ Transform = Callable[[pt.TensorLike], pt.TensorLike]
205
+
206
# Registry of user-supplied transforms, consulted before pytensor/pymc.math.
CUSTOM_TRANSFORMS: dict[str, Transform] = {}


def register_tensor_transform(name: str, transform: Transform) -> None:
    """Register a tensor transform function to be used in the `Prior` class.

    Parameters
    ----------
    name : str
        The name of the transform.
    transform : Callable[[pt.TensorLike], pt.TensorLike]
        The function to apply to the tensor.

    Examples
    --------
    Register a custom transform function.

    .. code-block:: python

        from pymc_extras.prior import (
            Prior,
            register_tensor_transform,
        )

        def custom_transform(x):
            return x ** 2

        register_tensor_transform("square", custom_transform)

        custom_distribution = Prior("Normal", transform="square")

    """
    # Re-registering a name silently overwrites the previous transform.
    CUSTOM_TRANSFORMS[name] = transform
239
+
240
+
241
def _get_transform(name: str):
    """Resolve a transform name to a callable.

    Lookup order: custom registry (``CUSTOM_TRANSFORMS``), then
    ``pytensor.tensor``, then ``pymc.math``.

    Raises
    ------
    UnknownTransformError
        If the name is not registered and not found in either module.
    """
    if name in CUSTOM_TRANSFORMS:
        return CUSTOM_TRANSFORMS[name]

    for module in (pt, pm.math):
        if hasattr(module, name):
            return getattr(module, name)

    msg = (
        f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
        "If this is a custom function, register it with the "
        "`pymc_extras.prior.register_tensor_transform` function before "
        "using it."
    )
    raise UnknownTransformError(msg)
262
+
263
+
264
+ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
265
+ return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}
266
+
267
+
268
@runtime_checkable
class VariableFactory(Protocol):
    """Protocol for something that works like a Prior class.

    Anything exposing ``dims`` and ``create_variable`` (e.g. ``Prior``,
    ``Censored``, or a user class) satisfies this protocol.
    """

    # Dims the created variable will carry; used by consumers such as
    # `sample_prior` to validate coords.
    dims: tuple[str, ...]

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a TensorVariable."""
276
+
277
+
278
def sample_prior(
    factory: VariableFactory,
    coords=None,
    name: str = "var",
    wrap: bool = False,
    **sample_prior_predictive_kwargs,
) -> xr.Dataset:
    """Sample the prior for an arbitrary VariableFactory.

    Parameters
    ----------
    factory : VariableFactory
        The factory to sample from.
    coords : dict[str, list[str]], optional
        The coordinates for the variable, by default None.
        Only required if the dims are specified.
    name : str, optional
        The name of the variable, by default "var".
    wrap : bool, optional
        Whether to wrap the variable in a `pm.Deterministic` node, by default False.
    sample_prior_predictive_kwargs : dict
        Additional arguments to pass to `pm.sample_prior_predictive`.

    Returns
    -------
    xr.Dataset
        The dataset of the prior samples.

    Raises
    ------
    KeyError
        If ``coords`` is missing an entry for one of the factory's dims.

    Example
    -------
    Sample from an arbitrary variable factory.

    .. code-block:: python

        from pymc_extras.prior import Prior, sample_prior

        dist = Prior("Normal", dims="channel")
        coords = {"channel": ["C1", "C2", "C3"]}
        prior = sample_prior(dist, coords=coords)

    """
    coords = coords or {}

    # Normalize a single string dim into a 1-tuple before validation.
    dims = (factory.dims,) if isinstance(factory.dims, str) else factory.dims

    if missing_keys := set(dims) - set(coords.keys()):
        raise KeyError(f"Coords are missing the following dims: {missing_keys}")

    with pm.Model(coords=coords) as model:
        variable = factory.create_variable(name)
        # The factory's return value isn't registered unless wrapped.
        if wrap:
            pm.Deterministic(name, variable, dims=factory.dims)

    return pm.sample_prior_predictive(
        model=model,
        **sample_prior_predictive_kwargs,
    ).prior
355
+
356
+
357
class Prior:
    """A class to represent a prior distribution.

    Make use of the various helper methods to understand the distributions
    better.

    - `preliz` attribute to get the equivalent distribution in `preliz`
    - `sample_prior` method to sample from the prior
    - `graph` get a dummy model graph with the distribution
    - `constrain` to shift the distribution to a different range

    Parameters
    ----------
    distribution : str
        The name of PyMC distribution.
    dims : Dims, optional
        The dimensions of the variable, by default None
    centered : bool, optional
        Whether the variable is centered or not, by default True.
        Only allowed for Normal distribution.
    transform : str, optional
        The name of the transform to apply to the variable after it is
        created, by default None or no transform. The transformation must
        be registered with `register_tensor_transform` function or
        be available in either `pytensor.tensor` or `pymc.math`.

    """

    # Distributions supporting a non-centered parameterization, with the
    # standardized location/scale used for the offset variable.
    # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
    non_centered_distributions: dict[str, dict[str, float]] = {
        "Normal": {"mu": 0, "sigma": 1},
        "StudentT": {"mu": 0, "sigma": 1},
        "ZeroSumNormal": {"sigma": 1},
    }

    # Resolved eagerly from the distribution / transform names in the setters.
    pymc_distribution: type[pm.Distribution]
    pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None

    @validate_call
    def __init__(
        self,
        distribution: str,
        *,
        dims: Dims | None = None,
        centered: bool = True,
        transform: str | None = None,
        **parameters,
    ) -> None:
        # Order matters: the dims setter validates against self.parameters,
        # so parameters must be assigned first.
        self.distribution = distribution
        self.parameters = parameters
        self.dims = dims
        self.centered = centered
        self.transform = transform

        self._checks()

    @property
    def distribution(self) -> str:
        """The name of the PyMC distribution."""
        return self._distribution

    @distribution.setter
    def distribution(self, distribution: str) -> None:
        # Write-once: changing it later would invalidate the cached
        # PyMC distribution class.
        if hasattr(self, "_distribution"):
            raise AttributeError("Can't change the distribution")

        self._distribution = distribution
        self.pymc_distribution = _get_pymc_distribution(distribution)

    @property
    def transform(self) -> str | None:
        """The name of the transform to apply to the variable after it is created."""
        return self._transform

    @transform.setter
    def transform(self, transform: str | None) -> None:
        self._transform = transform
        # Resolve eagerly so unknown transform names fail fast. Store None
        # (not a truthy placeholder) when there is no transform, matching the
        # `Callable | None` annotation of `pytensor_transform`.
        self.pytensor_transform = _get_transform(transform) if transform else None

    @property
    def dims(self) -> Dims:
        """The dimensions of the variable."""
        return self._dims

    @dims.setter
    def dims(self, dims) -> None:
        # Normalize str / list inputs to a tuple; None becomes ().
        if isinstance(dims, str):
            dims = (dims,)

        if isinstance(dims, list):
            dims = tuple(dims)

        self._dims = dims or ()

        self._param_dims_work()
        self._unique_dims()

    def __getitem__(self, key: str) -> Prior | Any:
        """Return the parameter of the prior."""
        return self.parameters[key]

    def _checks(self) -> None:
        # Validate construction arguments; raises on any inconsistency.
        if not self.centered:
            self._correct_non_centered_distribution()

        self._parameters_are_at_least_subset_of_pymc()
        self._convert_lists_to_numpy()
        self._parameters_are_correct_type()

    def _parameters_are_at_least_subset_of_pymc(self) -> None:
        # Every given parameter name must exist on the PyMC distribution.
        pymc_params = _get_pymc_parameters(self.pymc_distribution)
        if not set(self.parameters.keys()).issubset(pymc_params):
            msg = (
                f"Parameters {set(self.parameters.keys())} "
                "are not a subset of the pymc distribution "
                f"parameters {set(pymc_params)}"
            )
            raise ValueError(msg)

    def _convert_lists_to_numpy(self) -> None:
        # Lists are stored as numpy arrays so equality and broadcasting work.
        def convert(x):
            if not isinstance(x, list):
                return x

            return np.array(x)

        self.parameters = {key: convert(value) for key, value in self.parameters.items()}

    def _parameters_are_correct_type(self) -> None:
        supported_types = (
            int,
            float,
            np.ndarray,
            Prior,
            pt.TensorVariable,
            VariableFactory,
        )

        incorrect_types = {
            param: type(value)
            for param, value in self.parameters.items()
            if not isinstance(value, supported_types)
        }
        if incorrect_types:
            # Keep this message in sync with `supported_types` above.
            msg = (
                "Parameters must be one of the following types: "
                "(int, float, np.array, Prior, pt.TensorVariable, VariableFactory). "
                f"Incorrect parameters: {incorrect_types}"
            )
            raise ValueError(msg)

    def _correct_non_centered_distribution(self) -> None:
        if not self.centered and self.distribution not in self.non_centered_distributions:
            raise UnsupportedParameterizationError(
                f"{self.distribution!r} is not supported for non-centered parameterization. "
                f"Choose from {list(self.non_centered_distributions.keys())}"
            )

        required_parameters = set(self.non_centered_distributions[self.distribution].keys())

        # All location/scale parameters must be given explicitly. Using
        # `issubset` (instead of a strict-subset comparison) also catches the
        # case where unrelated parameters are present but a required one is
        # missing, which previously surfaced as a raw KeyError later on.
        if not required_parameters.issubset(self.parameters.keys()):
            msg = " and ".join([f"{param!r}" for param in required_parameters])
            raise ValueError(
                f"Must have at least {msg} parameter for non-centered for {self.distribution!r}"
            )

    def _unique_dims(self) -> None:
        if not self.dims:
            return

        if len(self.dims) != len(set(self.dims)):
            raise ValueError("Dims must be unique")

    def _param_dims_work(self) -> None:
        # Any dims used by parameter priors must be covered by this prior's dims.
        other_dims = set()
        for value in self.parameters.values():
            if hasattr(value, "dims"):
                other_dims.update(value.dims)

        if not other_dims.issubset(self.dims):
            raise UnsupportedShapeError(
                f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
            )

    def __str__(self) -> str:
        """Return a string representation of the prior."""
        param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()])
        param_str = "" if not param_str else f", {param_str}"

        dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else ""
        centered_str = f", centered={self.centered}" if not self.centered else ""
        transform_str = f', transform="{self.transform}"' if self.transform else ""
        return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})'

    def __repr__(self) -> str:
        """Return a string representation of the prior."""
        return f"{self}"

    def _create_parameter(self, param, value, name):
        # Plain values pass through; Prior-like values become child variables
        # aligned to this prior's dims.
        if not hasattr(value, "create_variable"):
            return value

        child_name = f"{name}_{param}"
        return self.dim_handler(value.create_variable(child_name), value.dims)

    def _create_centered_variable(self, name: str):
        parameters = {
            param: self._create_parameter(param, value, name)
            for param, value in self.parameters.items()
        }
        return self.pymc_distribution(name, **parameters, dims=self.dims)

    def _create_non_centered_variable(self, name: str) -> pt.TensorVariable:
        def handle_variable(var_name: str):
            parameter = self.parameters[var_name]
            if not hasattr(parameter, "create_variable"):
                return parameter

            return self.dim_handler(
                parameter.create_variable(f"{name}_{var_name}"),
                parameter.dims,
            )

        # Sample a standardized "offset" and shift/scale it deterministically:
        # var = mu + sigma * offset.
        defaults = self.non_centered_distributions[self.distribution]
        other_parameters = {
            param: handle_variable(param)
            for param in self.parameters.keys()
            if param not in defaults
        }
        offset = self.pymc_distribution(
            f"{name}_offset",
            **defaults,
            **other_parameters,
            dims=self.dims,
        )
        if "mu" in self.parameters:
            mu = (
                handle_variable("mu")
                if isinstance(self.parameters["mu"], Prior)
                else self.parameters["mu"]
            )
        else:
            # ZeroSumNormal has no mu parameter; default the location to 0.
            mu = 0

        sigma = (
            handle_variable("sigma")
            if isinstance(self.parameters["sigma"], Prior)
            else self.parameters["sigma"]
        )

        return pm.Deterministic(
            name,
            mu + sigma * offset,
            dims=self.dims,
        )

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a PyMC variable from the prior.

        Must be used in a PyMC model context.

        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        pt.TensorVariable
            The PyMC variable.

        Examples
        --------
        Create a hierarchical normal variable in larger PyMC model.

        .. code-block:: python

            dist = Prior(
                "Normal",
                mu=Prior("Normal"),
                sigma=Prior("HalfNormal"),
                dims="channel",
            )

            coords = {"channel": ["C1", "C2", "C3"]}
            with pm.Model(coords=coords):
                var = dist.create_variable("var")

        """
        self.dim_handler = create_dim_handler(self.dims)

        # With a transform, the raw variable is created under "<name>_raw"
        # and the transformed value is registered under "<name>".
        if self.transform:
            var_name = f"{name}_raw"

            def transform(var):
                return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims)
        else:
            var_name = name

            def transform(var):
                return var

        create_variable = (
            self._create_centered_variable if self.centered else self._create_non_centered_variable
        )
        var = create_variable(name=var_name)
        return transform(var)

    @property
    def preliz(self):
        """Create an equivalent preliz distribution.

        Helpful to visualize a distribution when it is univariate.

        Returns
        -------
        preliz.distributions.Distribution

        Examples
        --------
        Create a preliz distribution from a prior.

        .. code-block:: python

            from pymc_extras.prior import Prior

            dist = Prior("Gamma", alpha=5, beta=1)
            dist.preliz.plot_pdf()

        """
        import preliz as pz

        return getattr(pz, self.distribution)(**self.parameters)

    def to_dict(self) -> dict[str, Any]:
        """Convert the prior to dictionary format.

        Returns
        -------
        dict[str, Any]
            The dictionary format of the prior.

        Examples
        --------
        Convert a prior to the dictionary format.

        .. code-block:: python

            from pymc_extras.prior import Prior

            dist = Prior("Normal", mu=0, sigma=1)

            dist.to_dict()

        Convert a hierarchical prior to the dictionary format.

        .. code-block:: python

            dist = Prior(
                "Normal",
                mu=Prior("Normal"),
                sigma=Prior("HalfNormal"),
                dims="channel",
            )

            dist.to_dict()

        """
        data: dict[str, Any] = {
            "dist": self.distribution,
        }
        if self.parameters:

            def handle_value(value):
                # Nested priors and serializable objects recurse; tensors are
                # evaluated and arrays become JSON-friendly lists.
                if isinstance(value, Prior):
                    return value.to_dict()

                if isinstance(value, pt.TensorVariable):
                    value = value.eval()

                if isinstance(value, np.ndarray):
                    return value.tolist()

                if hasattr(value, "to_dict"):
                    return value.to_dict()

                return value

            data["kwargs"] = {
                param: handle_value(value) for param, value in self.parameters.items()
            }
        if not self.centered:
            data["centered"] = False

        if self.dims:
            data["dims"] = self.dims

        if self.transform:
            data["transform"] = self.transform

        return data

    @classmethod
    def from_dict(cls, data) -> Prior:
        """Create a Prior from the dictionary format.

        Parameters
        ----------
        data : dict[str, Any]
            The dictionary format of the prior.

        Returns
        -------
        Prior
            The prior distribution.

        Examples
        --------
        Convert prior in the dictionary format to a Prior instance.

        .. code-block:: python

            from pymc_extras.prior import Prior

            data = {
                "dist": "Normal",
                "kwargs": {"mu": 0, "sigma": 1},
            }

            dist = Prior.from_dict(data)
            dist
            # Prior("Normal", mu=0, sigma=1)

        """
        if not isinstance(data, dict):
            msg = (
                "Must be a dictionary representation of a prior distribution. "
                f"Not of type: {type(data)}"
            )
            raise ValueError(msg)

        dist = data["dist"]
        kwargs = data.get("kwargs", {})

        def handle_value(value):
            # Nested dicts are deserialized recursively; lists become arrays.
            if isinstance(value, dict):
                return deserialize(value)

            if isinstance(value, list):
                return np.array(value)

            return value

        kwargs = {param: handle_value(value) for param, value in kwargs.items()}
        centered = data.get("centered", True)
        dims = data.get("dims")
        if isinstance(dims, list):
            dims = tuple(dims)
        transform = data.get("transform")

        return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs)

    def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior:
        """Create a new prior with a given mass constrained within the given bounds.

        Wrapper around `preliz.maxent`.

        Parameters
        ----------
        lower : float
            The lower bound.
        upper : float
            The upper bound.
        mass: float = 0.95
            The mass of the distribution to keep within the bounds.
        kwargs : dict
            Additional arguments to pass to `pz.maxent`.

        Returns
        -------
        Prior
            The maximum entropy prior with a mass constrained to the given bounds.

        Examples
        --------
        Create a Beta distribution that is constrained to have 95% of the mass
        between 0.5 and 0.8.

        .. code-block:: python

            dist = Prior(
                "Beta",
            ).constrain(lower=0.5, upper=0.8)

        Create a Beta distribution with mean 0.6, that is constrained to
        have 95% of the mass between 0.5 and 0.8.

        .. code-block:: python

            dist = Prior(
                "Beta",
                mu=0.6,
            ).constrain(lower=0.5, upper=0.8)

        """
        from preliz import maxent

        if self.transform:
            raise ValueError("Can't constrain a transformed variable")

        if kwargs is None:
            kwargs = {}
        kwargs.setdefault("plot", False)

        # With plot=True, maxent returns (ax, dist); otherwise just dist.
        if kwargs["plot"]:
            new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict
        else:
            new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict

        return Prior(
            self.distribution,
            dims=self.dims,
            transform=self.transform,
            centered=self.centered,
            **new_parameters,
        )

    def __eq__(self, other) -> bool:
        """Check if two priors are equal."""
        if not isinstance(other, Prior):
            return False

        # np.testing handles nested dicts containing arrays.
        try:
            np.testing.assert_equal(self.parameters, other.parameters)
        except AssertionError:
            return False

        return (
            self.distribution == other.distribution
            and self.dims == other.dims
            and self.centered == other.centered
            and self.transform == other.transform
        )

    def sample_prior(
        self,
        coords=None,
        name: str = "var",
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample the prior distribution for the variable.

        Parameters
        ----------
        coords : dict[str, list[str]], optional
            The coordinates for the variable, by default None.
            Only required if the dims are specified.
        name : str, optional
            The name of the variable, by default "var".
        sample_prior_predictive_kwargs : dict
            Additional arguments to pass to `pm.sample_prior_predictive`.

        Returns
        -------
        xr.Dataset
            The dataset of the prior samples.

        Example
        -------
        Sample from a hierarchical normal distribution.

        .. code-block:: python

            dist = Prior(
                "Normal",
                mu=Prior("Normal"),
                sigma=Prior("HalfNormal"),
                dims="channel",
            )

            coords = {"channel": ["C1", "C2", "C3"]}
            prior = dist.sample_prior(coords=coords)

        """
        return sample_prior(
            factory=self,
            coords=coords,
            name=name,
            **sample_prior_predictive_kwargs,
        )

    def __deepcopy__(self, memo) -> Prior:
        """Return a deep copy of the prior."""
        if id(self) in memo:
            return memo[id(self)]

        copy_obj = Prior(
            self.distribution,
            dims=copy.copy(self.dims),
            centered=self.centered,
            transform=self.transform,
            **copy.deepcopy(self.parameters),
        )
        memo[id(self)] = copy_obj
        return copy_obj

    def deepcopy(self) -> Prior:
        """Return a deep copy of the prior."""
        return copy.deepcopy(self)

    def to_graph(self):
        """Generate a graph of the variables.

        Examples
        --------
        Create the graph for a 2D transformed hierarchical distribution.

        .. code-block:: python

            from pymc_extras.prior import Prior

            mu = Prior(
                "Normal",
                mu=Prior("Normal"),
                sigma=Prior("HalfNormal"),
                dims="channel",
            )
            sigma = Prior("HalfNormal", dims="channel")
            dist = Prior(
                "Normal",
                mu=mu,
                sigma=sigma,
                dims=("channel", "geo"),
                centered=False,
                transform="sigmoid",
            )

            dist.to_graph()

        .. image:: /_static/example-graph.png
            :alt: Example graph

        """
        # Single dummy coordinate per dim is enough to build the graph.
        coords = {name: ["DUMMY"] for name in self.dims}
        with pm.Model(coords=coords) as model:
            self.create_variable("var")

        return pm.model_to_graphviz(model)

    def create_likelihood_variable(
        self,
        name: str,
        mu: pt.TensorLike,
        observed: pt.TensorLike,
    ) -> pt.TensorVariable:
        """Create a likelihood variable from the prior.

        Will require that the distribution has a `mu` parameter
        and that it has not been set in the parameters.

        Parameters
        ----------
        name : str
            The name of the variable.
        mu : pt.TensorLike
            The mu parameter for the likelihood.
        observed : pt.TensorLike
            The observed data.

        Returns
        -------
        pt.TensorVariable
            The PyMC variable.

        Examples
        --------
        Create a likelihood variable in a larger PyMC model.

        .. code-block:: python

            import pymc as pm

            dist = Prior("Normal", sigma=Prior("HalfNormal"))

            with pm.Model():
                # Create the likelihood variable
                mu = pm.Normal("mu", mu=0, sigma=1)
                dist.create_likelihood_variable("y", mu=mu, observed=observed)

        """
        if "mu" not in _get_pymc_parameters(self.pymc_distribution):
            raise UnsupportedDistributionError(
                f"Likelihood distribution {self.distribution!r} is not supported."
            )

        if "mu" in self.parameters:
            raise MuAlreadyExistsError(self)

        # Work on a copy so the original prior stays reusable.
        distribution = self.deepcopy()
        distribution.parameters["mu"] = mu
        distribution.parameters["observed"] = observed
        return distribution.create_variable(name)
1058
+
1059
+
1060
class VariableNotFound(Exception):
    """Raised when a variable cannot be located in the current model."""
1062
+
1063
+
1064
def _remove_random_variable(var: pt.TensorVariable) -> None:
    """Detach `var` from the ambient PyMC model's bookkeeping.

    Used so a variable's distribution can be re-registered under the same
    name (e.g. wrapped in ``pm.Censored``) without a duplicate-name clash.

    Raises
    ------
    ValueError
        If the variable has no name.
    VariableNotFound
        If the variable is not among the model's free RVs.
    """
    if var.name is None:
        raise ValueError("This isn't removable")

    name: str = var.name

    model = pm.modelcontext(None)
    for idx, free_rv in enumerate(model.free_RVs):
        if var == free_rv:
            index_to_remove = idx
            break
    else:
        raise VariableNotFound(f"Variable {var.name!r} not found")

    # Clear the name first so the model no longer associates it, then drop
    # the variable from both registries under its original name.
    var.name = None
    model.free_RVs.pop(index_to_remove)
    model.named_vars.pop(name)
1081
+
1082
+
1083
@dataclass
class Censored:
    """Create censored random variable.

    Examples
    --------
    Create a censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior("Normal")
        censored_normal = Censored(normal, lower=0)

    Create hierarchical censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )
        censored_normal = Censored(normal, lower=0)

        coords = {"channel": range(3)}
        samples = censored_normal.sample_prior(coords=coords)

    """

    # Base distribution to censor; must be centered and untransformed
    # (validated in __post_init__).
    distribution: InstanceOf[Prior]
    # Censoring bounds; +/- inf leaves that side unbounded.
    lower: float | InstanceOf[pt.TensorVariable] = -np.inf
    upper: float | InstanceOf[pt.TensorVariable] = np.inf
1120
+
1121
+ def __post_init__(self) -> None:
1122
+ """Check validity at initialization."""
1123
+ if not self.distribution.centered:
1124
+ raise ValueError(
1125
+ "Censored distribution must be centered so that .dist() API can be used on distribution."
1126
+ )
1127
+
1128
+ if self.distribution.transform is not None:
1129
+ raise ValueError(
1130
+ "Censored distribution can't have a transform so that .dist() API can be used on distribution."
1131
+ )
1132
+
1133
+ @property
1134
+ def dims(self) -> tuple[str, ...]:
1135
+ """The dims from the distribution to censor."""
1136
+ return self.distribution.dims
1137
+
1138
+ @dims.setter
1139
+ def dims(self, dims) -> None:
1140
+ self.distribution.dims = dims
1141
+
1142
+ def create_variable(self, name: str) -> pt.TensorVariable:
1143
+ """Create censored random variable."""
1144
+ dist = self.distribution.create_variable(name)
1145
+ _remove_random_variable(var=dist)
1146
+
1147
+ return pm.Censored(
1148
+ name,
1149
+ dist,
1150
+ lower=self.lower,
1151
+ upper=self.upper,
1152
+ dims=self.dims,
1153
+ )
1154
+
1155
+ def to_dict(self) -> dict[str, Any]:
1156
+ """Convert the censored distribution to a dictionary."""
1157
+
1158
+ def handle_value(value):
1159
+ if isinstance(value, pt.TensorVariable):
1160
+ return value.eval().tolist()
1161
+
1162
+ return value
1163
+
1164
+ return {
1165
+ "class": "Censored",
1166
+ "data": {
1167
+ "dist": self.distribution.to_dict(),
1168
+ "lower": handle_value(self.lower),
1169
+ "upper": handle_value(self.upper),
1170
+ },
1171
+ }
1172
+
1173
+ @classmethod
1174
+ def from_dict(cls, data: dict[str, Any]) -> Censored:
1175
+ """Create a censored distribution from a dictionary."""
1176
+ data = data["data"]
1177
+ return cls( # type: ignore
1178
+ distribution=Prior.from_dict(data["dist"]),
1179
+ lower=data["lower"],
1180
+ upper=data["upper"],
1181
+ )
1182
+
1183
+ def sample_prior(
1184
+ self,
1185
+ coords=None,
1186
+ name: str = "variable",
1187
+ **sample_prior_predictive_kwargs,
1188
+ ) -> xr.Dataset:
1189
+ """Sample the prior distribution for the variable.
1190
+
1191
+ Parameters
1192
+ ----------
1193
+ coords : dict[str, list[str]], optional
1194
+ The coordinates for the variable, by default None.
1195
+ Only required if the dims are specified.
1196
+ name : str, optional
1197
+ The name of the variable, by default "var".
1198
+ sample_prior_predictive_kwargs : dict
1199
+ Additional arguments to pass to `pm.sample_prior_predictive`.
1200
+
1201
+ Returns
1202
+ -------
1203
+ xr.Dataset
1204
+ The dataset of the prior samples.
1205
+
1206
+ Example
1207
+ -------
1208
+ Sample from a censored Gamma distribution.
1209
+
1210
+ .. code-block:: python
1211
+
1212
+ gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
1213
+ dist = Censored(gamma, lower=0.5)
1214
+
1215
+ coords = {"channel": ["C1", "C2", "C3"]}
1216
+ prior = dist.sample_prior(coords=coords)
1217
+
1218
+ """
1219
+ return sample_prior(
1220
+ factory=self,
1221
+ coords=coords,
1222
+ name=name,
1223
+ **sample_prior_predictive_kwargs,
1224
+ )
1225
+
1226
+ def to_graph(self):
1227
+ """Generate a graph of the variables.
1228
+
1229
+ Examples
1230
+ --------
1231
+ Create graph for a censored Normal distribution
1232
+
1233
+ .. code-block:: python
1234
+
1235
+ from pymc_extras.prior import Prior, Censored
1236
+
1237
+ normal = Prior("Normal")
1238
+ censored_normal = Censored(normal, lower=0)
1239
+
1240
+ censored_normal.to_graph()
1241
+
1242
+ """
1243
+ coords = {name: ["DUMMY"] for name in self.dims}
1244
+ with pm.Model(coords=coords) as model:
1245
+ self.create_variable("var")
1246
+
1247
+ return pm.model_to_graphviz(model)
1248
+
1249
+ def create_likelihood_variable(
1250
+ self,
1251
+ name: str,
1252
+ mu: pt.TensorLike,
1253
+ observed: pt.TensorLike,
1254
+ ) -> pt.TensorVariable:
1255
+ """Create observed censored variable.
1256
+
1257
+ Will require that the distribution has a `mu` parameter
1258
+ and that it has not been set in the parameters.
1259
+
1260
+ Parameters
1261
+ ----------
1262
+ name : str
1263
+ The name of the variable.
1264
+ mu : pt.TensorLike
1265
+ The mu parameter for the likelihood.
1266
+ observed : pt.TensorLike
1267
+ The observed data.
1268
+
1269
+ Returns
1270
+ -------
1271
+ pt.TensorVariable
1272
+ The PyMC variable.
1273
+
1274
+ Examples
1275
+ --------
1276
+ Create a censored likelihood variable in a larger PyMC model.
1277
+
1278
+ .. code-block:: python
1279
+
1280
+ import pymc as pm
1281
+ from pymc_extras.prior import Prior, Censored
1282
+
1283
+ normal = Prior("Normal", sigma=Prior("HalfNormal"))
1284
+ dist = Censored(normal, lower=0)
1285
+
1286
+ observed = 1
1287
+
1288
+ with pm.Model():
1289
+ # Create the likelihood variable
1290
+ mu = pm.HalfNormal("mu", sigma=1)
1291
+ dist.create_likelihood_variable("y", mu=mu, observed=observed)
1292
+
1293
+ """
1294
+ if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
1295
+ raise UnsupportedDistributionError(
1296
+ f"Likelihood distribution {self.distribution.distribution!r} is not supported."
1297
+ )
1298
+
1299
+ if "mu" in self.distribution.parameters:
1300
+ raise MuAlreadyExistsError(self.distribution)
1301
+
1302
+ distribution = self.distribution.deepcopy()
1303
+ distribution.parameters["mu"] = mu
1304
+
1305
+ dist = distribution.create_variable(name)
1306
+ _remove_random_variable(var=dist)
1307
+
1308
+ return pm.Censored(
1309
+ name,
1310
+ dist,
1311
+ observed=observed,
1312
+ lower=self.lower,
1313
+ upper=self.upper,
1314
+ dims=self.dims,
1315
+ )
1316
+
1317
+
1318
class Scaled:
    """Scaled distribution for numerical stability.

    Wraps a :class:`Prior` and multiplies its samples by a constant factor,
    exposing the scaled result as a ``pm.Deterministic``.
    """

    def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
        self.dist = dist
        self.factor = factor

    @property
    def dims(self) -> Dims:
        """The dimensions of the scaled distribution (taken from the wrapped prior)."""
        return self.dist.dims

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a scaled variable.

        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        pt.TensorVariable
            The scaled variable.
        """
        # The raw draw is registered under "<name>_unscaled"; the scaled
        # value is what callers see under `name`.
        unscaled = self.dist.create_variable(f"{name}_unscaled")
        scaled = unscaled * self.factor
        return pm.Deterministic(name, scaled, dims=self.dims)
1345
+
1346
+
1347
+ def _is_prior_type(data: dict) -> bool:
1348
+ return "dist" in data
1349
+
1350
+
1351
+ def _is_censored_type(data: dict) -> bool:
1352
+ return data.keys() == {"class", "data"} and data["class"] == "Censored"
1353
+
1354
+
1355
+ register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
1356
+ register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)