pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/deserialize.py +224 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/prior.py +1356 -0
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.7.dist-info/METADATA +321 -0
- pymc_extras-0.2.7.dist-info/RECORD +66 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
- pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
pymc_extras/prior.py
ADDED
|
@@ -0,0 +1,1356 @@
|
|
|
1
|
+
"""Class that represents a prior distribution.
|
|
2
|
+
|
|
3
|
+
The `Prior` class is a wrapper around PyMC distributions that allows the user
|
|
4
|
+
to create outside of the PyMC model.
|
|
5
|
+
|
|
6
|
+
Examples
|
|
7
|
+
--------
|
|
8
|
+
Create a normal prior.
|
|
9
|
+
|
|
10
|
+
.. code-block:: python
|
|
11
|
+
|
|
12
|
+
from pymc_extras.prior import Prior
|
|
13
|
+
|
|
14
|
+
normal = Prior("Normal")
|
|
15
|
+
|
|
16
|
+
Create a hierarchical normal prior by using distributions for the parameters
|
|
17
|
+
and specifying the dims.
|
|
18
|
+
|
|
19
|
+
.. code-block:: python
|
|
20
|
+
|
|
21
|
+
hierarchical_normal = Prior(
|
|
22
|
+
"Normal",
|
|
23
|
+
mu=Prior("Normal"),
|
|
24
|
+
sigma=Prior("HalfNormal"),
|
|
25
|
+
dims="channel",
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
Create a non-centered hierarchical normal prior with the `centered` parameter.
|
|
29
|
+
|
|
30
|
+
.. code-block:: python
|
|
31
|
+
|
|
32
|
+
non_centered_hierarchical_normal = Prior(
|
|
33
|
+
"Normal",
|
|
34
|
+
mu=Prior("Normal"),
|
|
35
|
+
sigma=Prior("HalfNormal"),
|
|
36
|
+
dims="channel",
|
|
37
|
+
# Only change needed to make it non-centered
|
|
38
|
+
centered=False,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
Create a hierarchical beta prior by using Beta distribution, distributions for
|
|
42
|
+
the parameters, and specifying the dims.
|
|
43
|
+
|
|
44
|
+
.. code-block:: python
|
|
45
|
+
|
|
46
|
+
hierarchical_beta = Prior(
|
|
47
|
+
"Beta",
|
|
48
|
+
alpha=Prior("HalfNormal"),
|
|
49
|
+
beta=Prior("HalfNormal"),
|
|
50
|
+
dims="channel",
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
Create a transformed hierarchical normal prior by using the `transform`
|
|
54
|
+
parameter. Here the "sigmoid" transformation comes from `pm.math`.
|
|
55
|
+
|
|
56
|
+
.. code-block:: python
|
|
57
|
+
|
|
58
|
+
transformed_hierarchical_normal = Prior(
|
|
59
|
+
"Normal",
|
|
60
|
+
mu=Prior("Normal"),
|
|
61
|
+
sigma=Prior("HalfNormal"),
|
|
62
|
+
transform="sigmoid",
|
|
63
|
+
dims="channel",
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
Create a prior with a custom transform function by registering it with
|
|
67
|
+
`register_tensor_transform`.
|
|
68
|
+
|
|
69
|
+
.. code-block:: python
|
|
70
|
+
|
|
71
|
+
from pymc_extras.prior import register_tensor_transform
|
|
72
|
+
|
|
73
|
+
def custom_transform(x):
|
|
74
|
+
return x ** 2
|
|
75
|
+
|
|
76
|
+
register_tensor_transform("square", custom_transform)
|
|
77
|
+
|
|
78
|
+
custom_distribution = Prior("Normal", transform="square")
|
|
79
|
+
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
from __future__ import annotations
|
|
83
|
+
|
|
84
|
+
import copy
|
|
85
|
+
|
|
86
|
+
from collections.abc import Callable
|
|
87
|
+
from inspect import signature
|
|
88
|
+
from typing import Any, Protocol, runtime_checkable
|
|
89
|
+
|
|
90
|
+
import numpy as np
|
|
91
|
+
import pymc as pm
|
|
92
|
+
import pytensor.tensor as pt
|
|
93
|
+
import xarray as xr
|
|
94
|
+
|
|
95
|
+
from pydantic import InstanceOf, validate_call
|
|
96
|
+
from pydantic.dataclasses import dataclass
|
|
97
|
+
from pymc.distributions.shape_utils import Dims
|
|
98
|
+
|
|
99
|
+
from pymc_extras.deserialize import deserialize, register_deserialization
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class UnsupportedShapeError(Exception):
    """Raised when the dims of two variables cannot be aligned."""


class UnsupportedDistributionError(Exception):
    """Raised when a requested distribution does not exist in PyMC."""


class UnsupportedParameterizationError(Exception):
    """Raised when the requested parameterization is not supported."""


class MuAlreadyExistsError(Exception):
    """Raised when 'mu' is already present in a Prior."""

    def __init__(self, distribution: Prior) -> None:
        # Keep the offending Prior around so callers can inspect it.
        self.distribution = distribution
        self.message = f"The mu parameter is already defined in {distribution}"
        super().__init__(self.message)


class UnknownTransformError(Exception):
    """Raised when a transform name cannot be resolved."""
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
|
|
128
|
+
"""Remove leading 'x' from the args."""
|
|
129
|
+
while args and args[0] == "x":
|
|
130
|
+
args.pop(0)
|
|
131
|
+
|
|
132
|
+
return args
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
    """Align a tensor labelled with `dims` onto the axis order `desired_dims`.

    Axes present in `desired_dims` but not in `dims` become broadcastable
    ("x") axes. Validity of the dim names themselves is not checked.

    Examples
    --------
    1D to 2D with new dim

    .. code-block:: python

        x = np.array([1, 2, 3])
        dims = "channel"

        desired_dims = ("channel", "group")

        handle_dims(x, dims, desired_dims)

    """
    x = pt.as_tensor_variable(x)

    # Scalars broadcast against anything; nothing to rearrange.
    if np.ndim(x) == 0:
        return x

    if not isinstance(dims, tuple):
        dims = (dims,)
    if not isinstance(desired_dims, tuple):
        desired_dims = (desired_dims,)

    if difference := set(dims).difference(desired_dims):
        raise UnsupportedShapeError(
            f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
            f"{difference} is missing from the desired dims."
        )

    # alignment[i, j] is True when input axis i should land on output axis j.
    alignment = np.array(dims)[:, None] == np.array(desired_dims)
    absent = alignment.sum(axis=0) == 0
    source_axis = alignment.argmax(axis=0)

    shuffle = [
        "x" if is_absent else axis
        for axis, is_absent in zip(source_axis, absent, strict=False)
    ]
    # Leading broadcastable axes are redundant and dropped.
    shuffle = _remove_leading_xs(shuffle)
    return x.dimshuffle(*shuffle)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]


def create_dim_handler(desired_dims: Dims) -> DimHandler:
    """Return a function aligning ``(tensor, dims)`` pairs onto `desired_dims`.

    Thin closure over :func:`handle_dims`, kept for backwards compatibility
    with the previous ``create_dim_handler`` API.
    """

    def align(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
        return handle_dims(x, dims, desired_dims)

    return align
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _dims_to_str(obj: tuple[str, ...]) -> str:
|
|
191
|
+
if len(obj) == 1:
|
|
192
|
+
return f'"{obj[0]}"'
|
|
193
|
+
|
|
194
|
+
return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
    """Fetch the distribution class called *name* from the ``pymc`` namespace."""
    if not hasattr(pm, name):
        raise UnsupportedDistributionError(
            f"PyMC doesn't have a distribution of name {name!r}"
        )

    return getattr(pm, name)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
Transform = Callable[[pt.TensorLike], pt.TensorLike]

# Registry consulted by ``_get_transform`` before falling back to
# pytensor.tensor / pymc.math lookups.
CUSTOM_TRANSFORMS: dict[str, Transform] = {}


def register_tensor_transform(name: str, transform: Transform) -> None:
    """Register a custom tensor transform under *name* for use in ``Prior``.

    Parameters
    ----------
    name : str
        The name of the transform, i.e. what is passed as
        ``Prior(..., transform=name)``.
    transform : Callable[[pt.TensorLike], pt.TensorLike]
        The function to apply to the tensor.

    Examples
    --------
    Register a custom transform function.

    .. code-block:: python

        from pymc_extras.prior import (
            Prior,
            register_tensor_transform,
        )

        def custom_transform(x):
            return x ** 2

        register_tensor_transform("square", custom_transform)

        custom_distribution = Prior("Normal", transform="square")

    """
    CUSTOM_TRANSFORMS[name] = transform
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _get_transform(name: str) -> Transform:
    """Resolve a transform function by name.

    Lookup order: the ``CUSTOM_TRANSFORMS`` registry first, then
    ``pytensor.tensor``, then ``pymc.math``.

    Raises
    ------
    UnknownTransformError
        If no transform of that name can be found anywhere.
    """
    if name in CUSTOM_TRANSFORMS:
        return CUSTOM_TRANSFORMS[name]

    # Simpler than the original for/else/None dance: return on first hit.
    for module in (pt, pm.math):
        if hasattr(module, name):
            return getattr(module, name)

    # Fixed error message: the original ended with the garbled phrase
    # "before previous function call."
    msg = (
        f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
        "If this is a custom function, register it with the "
        "`pymc_extras.prior.register_tensor_transform` function before "
        "this function call."
    )
    raise UnknownTransformError(msg)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
|
|
265
|
+
return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@runtime_checkable
class VariableFactory(Protocol):
    """Structural interface for anything that behaves like a ``Prior``.

    Implementors expose the dims of the variable they build and can
    create a tensor given a name.
    """

    # Dim names the created variable is laid out over.
    dims: tuple[str, ...]

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a TensorVariable."""
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def sample_prior(
    factory: VariableFactory,
    coords=None,
    name: str = "var",
    wrap: bool = False,
    **sample_prior_predictive_kwargs,
) -> xr.Dataset:
    """Sample the prior for an arbitrary VariableFactory.

    Parameters
    ----------
    factory : VariableFactory
        The factory to sample from; its ``create_variable`` is called inside
        a temporary model.
    coords : dict[str, list[str]], optional
        The coordinates for the variable, by default None.
        Only required if the dims are specified.
    name : str, optional
        The name of the variable, by default "var".
    wrap : bool, optional
        Whether to wrap the variable in a `pm.Deterministic` node, by default False.
    sample_prior_predictive_kwargs : dict
        Additional arguments to pass to `pm.sample_prior_predictive`.

    Returns
    -------
    xr.Dataset
        The dataset of the prior samples.

    Example
    -------
    Sample from an arbitrary variable factory.

    .. code-block:: python

        import pymc as pm
        import pytensor.tensor as pt

        from pymc_extras.prior import sample_prior

        class CustomVariableDefinition:
            def __init__(self, dims, n: int):
                self.dims = dims
                self.n = n

            def create_variable(self, name: str) -> "TensorVariable":
                x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
                return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)

        cubic = CustomVariableDefinition(dims=("channel",), n=3)
        coords = {"channel": ["C1", "C2", "C3"]}
        prior = sample_prior(cubic, coords=coords)

    """
    coords = coords or {}

    dims = (factory.dims,) if isinstance(factory.dims, str) else factory.dims

    if missing_keys := set(dims) - set(coords.keys()):
        raise KeyError(f"Coords are missing the following dims: {missing_keys}")

    with pm.Model(coords=coords) as model:
        if wrap:
            # NOTE(review): assumes the factory does not itself register a
            # variable called ``name`` — a Prior without a transform would
            # collide here. Confirm against the intended factory types.
            pm.Deterministic(name, factory.create_variable(name), dims=factory.dims)
        else:
            factory.create_variable(name)

    return pm.sample_prior_predictive(
        model=model,
        **sample_prior_predictive_kwargs,
    ).prior
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class Prior:
|
|
358
|
+
"""A class to represent a prior distribution.
|
|
359
|
+
|
|
360
|
+
Make use of the various helper methods to understand the distributions
|
|
361
|
+
better.
|
|
362
|
+
|
|
363
|
+
- `preliz` attribute to get the equivalent distribution in `preliz`
|
|
364
|
+
- `sample_prior` method to sample from the prior
|
|
365
|
+
- `graph` get a dummy model graph with the distribution
|
|
366
|
+
- `constrain` to shift the distribution to a different range
|
|
367
|
+
|
|
368
|
+
Parameters
|
|
369
|
+
----------
|
|
370
|
+
distribution : str
|
|
371
|
+
The name of PyMC distribution.
|
|
372
|
+
dims : Dims, optional
|
|
373
|
+
The dimensions of the variable, by default None
|
|
374
|
+
centered : bool, optional
|
|
375
|
+
Whether the variable is centered or not, by default True.
|
|
376
|
+
Only allowed for Normal distribution.
|
|
377
|
+
transform : str, optional
|
|
378
|
+
The name of the transform to apply to the variable after it is
|
|
379
|
+
created, by default None or no transform. The transformation must
|
|
380
|
+
be registered with `register_tensor_transform` function or
|
|
381
|
+
be available in either `pytensor.tensor` or `pymc.math`.
|
|
382
|
+
|
|
383
|
+
"""
|
|
384
|
+
|
|
385
|
+
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
|
|
386
|
+
non_centered_distributions: dict[str, dict[str, float]] = {
|
|
387
|
+
"Normal": {"mu": 0, "sigma": 1},
|
|
388
|
+
"StudentT": {"mu": 0, "sigma": 1},
|
|
389
|
+
"ZeroSumNormal": {"sigma": 1},
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
pymc_distribution: type[pm.Distribution]
|
|
393
|
+
pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
|
|
394
|
+
|
|
395
|
+
@validate_call
def __init__(
    self,
    distribution: str,
    *,
    dims: Dims | None = None,
    centered: bool = True,
    transform: str | None = None,
    **parameters,
) -> None:
    """Build a Prior wrapping the named PyMC distribution.

    NOTE: assignment order is significant — the ``distribution`` setter
    resolves the PyMC class, ``parameters`` must exist before the ``dims``
    setter validates parameter dims, and the ``transform`` setter resolves
    the transform callable.
    """
    self.distribution = distribution
    self.parameters = parameters
    self.dims = dims
    self.centered = centered
    self.transform = transform

    self._checks()
|
|
412
|
+
|
|
413
|
+
@property
def distribution(self) -> str:
    """The name of the PyMC distribution."""
    return self._distribution

@distribution.setter
def distribution(self, distribution: str) -> None:
    # Write-once: the resolved PyMC class is cached alongside the name,
    # so rebinding after construction is disallowed.
    if hasattr(self, "_distribution"):
        raise AttributeError("Can't change the distribution")

    self._distribution = distribution
    self.pymc_distribution = _get_pymc_distribution(distribution)
|
|
425
|
+
|
|
426
|
+
@property
def transform(self) -> str | None:
    """The name of the transform to apply to the variable after it is created."""
    return self._transform

@transform.setter
def transform(self, transform: str | None) -> None:
    self._transform = transform
    # Resolve eagerly so an unknown transform name fails at construction time.
    # Bug fix: the original assigned ``not transform or _get_transform(transform)``,
    # which stored ``True`` (not None) when no transform was set, contradicting
    # the declared ``Callable | None`` type of ``pytensor_transform``. The
    # callable is only ever used when ``self.transform`` is truthy, so storing
    # None here is behavior-preserving for all valid uses.
    self.pytensor_transform = _get_transform(transform) if transform else None
|
|
435
|
+
|
|
436
|
+
@property
def dims(self) -> Dims:
    """The dimensions of the variable."""
    return self._dims

@dims.setter
def dims(self, dims) -> None:
    # Normalize to a tuple: a bare string is a single dim; lists become tuples.
    if isinstance(dims, str):
        dims = (dims,)
    elif isinstance(dims, list):
        dims = tuple(dims)

    self._dims = dims or ()

    # Validate against the already-assigned parameters.
    self._param_dims_work()
    self._unique_dims()
|
|
453
|
+
|
|
454
|
+
def __getitem__(self, key: str) -> Prior | Any:
    """Dictionary-style access to the distribution's parameters."""
    return self.parameters[key]
|
|
457
|
+
|
|
458
|
+
def _checks(self) -> None:
    """Run all construction-time validation of the parameters."""
    # Non-centered parameterization is only defined for a few families.
    if not self.centered:
        self._correct_non_centered_distribution()

    self._parameters_are_at_least_subset_of_pymc()
    # Lists are normalized to arrays before the type check below.
    self._convert_lists_to_numpy()
    self._parameters_are_correct_type()
|
|
465
|
+
|
|
466
|
+
def _parameters_are_at_least_subset_of_pymc(self) -> None:
    """Reject parameter names the PyMC distribution does not accept."""
    pymc_params = _get_pymc_parameters(self.pymc_distribution)
    given = set(self.parameters.keys())
    if given.issubset(pymc_params):
        return

    raise ValueError(
        f"Parameters {given} "
        "are not a subset of the pymc distribution "
        f"parameters {set(pymc_params)}"
    )
|
|
475
|
+
|
|
476
|
+
def _convert_lists_to_numpy(self) -> None:
    """Replace plain-list parameter values with numpy arrays, in place."""
    self.parameters = {
        key: np.array(value) if isinstance(value, list) else value
        for key, value in self.parameters.items()
    }
|
|
484
|
+
|
|
485
|
+
def _parameters_are_correct_type(self) -> None:
    """Reject parameter values outside the supported set of types."""
    supported_types = (
        int,
        float,
        np.ndarray,
        Prior,
        pt.TensorVariable,
        VariableFactory,
    )

    incorrect_types = {
        param: type(value)
        for param, value in self.parameters.items()
        if not isinstance(value, supported_types)
    }
    if incorrect_types:
        raise ValueError(
            "Parameters must be one of the following types: "
            f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}"
        )
|
|
506
|
+
|
|
507
|
+
def _correct_non_centered_distribution(self) -> None:
    """Validate that a non-centered parameterization is possible.

    Raises
    ------
    UnsupportedParameterizationError
        If the distribution is not in the supported location-scale family.
    ValueError
        If any required location/scale parameter is missing.
    """
    if not self.centered and self.distribution not in self.non_centered_distributions:
        raise UnsupportedParameterizationError(
            f"{self.distribution!r} is not supported for non-centered parameterization. "
            f"Choose from {list(self.non_centered_distributions.keys())}"
        )

    required_parameters = set(self.non_centered_distributions[self.distribution].keys())

    # Bug fix: the original used a strict-subset test (``set(keys) < required``),
    # which missed parameter sets missing a required name but containing
    # extras — e.g. Normal with mu/tau passed the check and only failed
    # later with a KeyError on "sigma". ``issubset`` catches both cases.
    # ``sorted`` makes the error message deterministic.
    if not required_parameters.issubset(self.parameters.keys()):
        msg = " and ".join(f"{param!r}" for param in sorted(required_parameters))
        raise ValueError(
            f"Must have at least {msg} parameter for non-centered for {self.distribution!r}"
        )
|
|
521
|
+
|
|
522
|
+
def _unique_dims(self) -> None:
    """Reject duplicate entries in ``dims``."""
    if not self.dims:
        return

    if len(set(self.dims)) != len(self.dims):
        raise ValueError("Dims must be unique")
|
|
528
|
+
|
|
529
|
+
def _param_dims_work(self) -> None:
    """Check that every parameter's dims are contained in the prior's dims."""
    other_dims = set()
    for value in self.parameters.values():
        # Only dim-aware values (Prior, VariableFactory, ...) contribute.
        if hasattr(value, "dims"):
            other_dims.update(value.dims)

    if not other_dims.issubset(self.dims):
        raise UnsupportedShapeError(
            f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
        )
|
|
539
|
+
|
|
540
|
+
def __str__(self) -> str:
    """Code-like representation, e.g. ``Prior("Normal", mu=0, sigma=1)``."""
    parts = [f'"{self.distribution}"']
    parts.extend(f"{param}={value}" for param, value in self.parameters.items())
    if self.dims:
        parts.append(f"dims={_dims_to_str(self.dims)}")
    if not self.centered:
        parts.append(f"centered={self.centered}")
    if self.transform:
        parts.append(f'transform="{self.transform}"')
    return "Prior({})".format(", ".join(parts))
|
|
549
|
+
|
|
550
|
+
def __repr__(self) -> str:
    """Mirror ``__str__`` so reprs read as constructor calls."""
    return str(self)
|
|
553
|
+
|
|
554
|
+
def _create_parameter(self, param, value, name):
    """Resolve one parameter: plain values pass through, nested factories
    become child variables aligned onto ``self.dims``."""
    if hasattr(value, "create_variable"):
        child = value.create_variable(f"{name}_{param}")
        return self.dim_handler(child, value.dims)

    return value
|
|
560
|
+
|
|
561
|
+
def _create_centered_variable(self, name: str):
    """Instantiate the distribution directly with its resolved parameters."""
    resolved = {
        param: self._create_parameter(param, value, name)
        for param, value in self.parameters.items()
    }
    return self.pymc_distribution(name, **resolved, dims=self.dims)
|
|
567
|
+
|
|
568
|
+
def _create_non_centered_variable(self, name: str) -> pt.TensorVariable:
    """Create ``mu + sigma * offset`` where ``offset`` is a standardized draw.

    ``offset`` follows the distribution with its location/scale pinned at
    the family defaults; the named variable is the deterministic shift/scale.
    """

    def realize(var_name: str):
        # Nested factories become child variables aligned to self.dims;
        # plain values pass through untouched.
        parameter = self.parameters[var_name]
        if not hasattr(parameter, "create_variable"):
            return parameter

        return self.dim_handler(
            parameter.create_variable(f"{name}_{var_name}"),
            parameter.dims,
        )

    defaults = self.non_centered_distributions[self.distribution]
    other_parameters = {
        param: realize(param)
        for param in self.parameters.keys()
        if param not in defaults
    }
    offset = self.pymc_distribution(
        f"{name}_offset",
        **defaults,
        **other_parameters,
        dims=self.dims,
    )

    if "mu" in self.parameters:
        mu_value = self.parameters["mu"]
        mu = realize("mu") if isinstance(mu_value, Prior) else mu_value
    else:
        mu = 0

    # NOTE(review): "sigma" is looked up unconditionally, so a parameter
    # set without it (e.g. a tau-parameterized Normal) raises KeyError here
    # — confirm the upstream validation is meant to guarantee its presence.
    sigma_value = self.parameters["sigma"]
    sigma = realize("sigma") if isinstance(sigma_value, Prior) else sigma_value

    return pm.Deterministic(
        name,
        mu + sigma * offset,
        dims=self.dims,
    )
|
|
611
|
+
|
|
612
|
+
def create_variable(self, name: str) -> pt.TensorVariable:
    """Create a PyMC variable from the prior.

    Must be used in a PyMC model context.

    Parameters
    ----------
    name : str
        The name of the variable.

    Returns
    -------
    pt.TensorVariable
        The PyMC variable.

    Examples
    --------
    Create a hierarchical normal variable in a larger PyMC model.

    .. code-block:: python

        dist = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )

        coords = {"channel": ["C1", "C2", "C3"]}
        with pm.Model(coords=coords):
            var = dist.create_variable("var")

    """
    self.dim_handler = create_dim_handler(self.dims)

    if self.transform:
        # Sample under "<name>_raw" and expose the transformed value as <name>.
        raw_name = f"{name}_raw"

        def finalize(var):
            return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims)

    else:
        raw_name = name

        def finalize(var):
            return var

    builder = (
        self._create_centered_variable if self.centered else self._create_non_centered_variable
    )
    return finalize(builder(name=raw_name))
|
|
663
|
+
|
|
664
|
+
@property
def preliz(self):
    """The equivalent ``preliz`` distribution.

    Helpful to visualize a distribution when it is univariate.

    Returns
    -------
    preliz.distributions.Distribution

    Examples
    --------
    .. code-block:: python

        from pymc_extras.prior import Prior

        dist = Prior("Gamma", alpha=5, beta=1)
        dist.preliz.plot_pdf()

    """
    # Local import: preliz is an optional dependency.
    import preliz as pz

    return getattr(pz, self.distribution)(**self.parameters)
|
|
689
|
+
|
|
690
|
+
def to_dict(self) -> dict[str, Any]:
    """Serialize the prior (recursively) to dictionary format.

    Returns
    -------
    dict[str, Any]
        Always contains "dist"; "kwargs", "centered", "dims" and
        "transform" are included only when set.

    Examples
    --------
    .. code-block:: python

        from pymc_extras.prior import Prior

        dist = Prior("Normal", mu=0, sigma=1)
        dist.to_dict()

    """

    def serialize(value):
        if isinstance(value, Prior):
            return value.to_dict()

        # Tensors are evaluated to arrays, arrays to plain lists.
        if isinstance(value, pt.TensorVariable):
            value = value.eval()
        if isinstance(value, np.ndarray):
            return value.tolist()

        # Other VariableFactory-likes serialize themselves.
        if hasattr(value, "to_dict"):
            return value.to_dict()

        return value

    data: dict[str, Any] = {"dist": self.distribution}
    if self.parameters:
        data["kwargs"] = {
            param: serialize(value) for param, value in self.parameters.items()
        }
    if not self.centered:
        data["centered"] = False
    if self.dims:
        data["dims"] = self.dims
    if self.transform:
        data["transform"] = self.transform

    return data
|
|
757
|
+
|
|
758
|
+
@classmethod
def from_dict(cls, data) -> Prior:
    """Create a Prior from the dictionary format (inverse of ``to_dict``).

    Parameters
    ----------
    data : dict[str, Any]
        The dictionary format of the prior.

    Returns
    -------
    Prior
        The prior distribution.

    Raises
    ------
    ValueError
        If ``data`` is not a dictionary.

    Examples
    --------
    .. code-block:: python

        from pymc_extras.prior import Prior

        data = {
            "dist": "Normal",
            "kwargs": {"mu": 0, "sigma": 1},
        }

        dist = Prior.from_dict(data)
        # Prior("Normal", mu=0, sigma=1)

    """
    if not isinstance(data, dict):
        raise ValueError(
            "Must be a dictionary representation of a prior distribution. "
            f"Not of type: {type(data)}"
        )

    dist = data["dist"]

    def revive(value):
        # Nested dicts are serialized sub-objects; lists were numpy arrays.
        if isinstance(value, dict):
            return deserialize(value)
        if isinstance(value, list):
            return np.array(value)
        return value

    kwargs = {param: revive(value) for param, value in data.get("kwargs", {}).items()}

    dims = data.get("dims")
    if isinstance(dims, list):
        dims = tuple(dims)

    return cls(
        dist,
        dims=dims,
        centered=data.get("centered", True),
        transform=data.get("transform"),
        **kwargs,
    )
|
|
817
|
+
|
|
818
|
+
def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior:
    """Create a new prior with a given mass constrained within the given bounds.

    Wrapper around `preliz.maxent`.

    Parameters
    ----------
    lower : float
        The lower bound.
    upper : float
        The upper bound.
    mass: float = 0.95
        The mass of the distribution to keep within the bounds.
    kwargs : dict
        Additional arguments to pass to `pz.maxent`. The dict is copied
        internally and never mutated.

    Returns
    -------
    Prior
        The maximum entropy prior with a mass constrained to the given bounds.

    Examples
    --------
    Create a Beta distribution that is constrained to have 95% of the mass
    between 0.5 and 0.8.

    .. code-block:: python

        dist = Prior(
            "Beta",
        ).constrain(lower=0.5, upper=0.8)

    Create a Beta distribution with mean 0.6, that is constrained to
    have 95% of the mass between 0.5 and 0.8.

    .. code-block:: python

        dist = Prior(
            "Beta",
            mu=0.6,
        ).constrain(lower=0.5, upper=0.8)

    """
    from preliz import maxent

    if self.transform:
        raise ValueError("Can't constrain a transformed variable")

    # Copy before mutating: the previous implementation called
    # setdefault() on the caller-supplied dict, silently adding a
    # "plot" key to it as a side effect.
    kwargs = {} if kwargs is None else dict(kwargs)
    kwargs.setdefault("plot", False)

    result = maxent(self.preliz, lower, upper, mass, **kwargs)
    # With plot=True, maxent returns a tuple whose first element is the
    # fitted distribution; without plotting it returns the distribution
    # directly.
    if kwargs["plot"]:
        result = result[0]
    new_parameters = result.params_dict

    return Prior(
        self.distribution,
        dims=self.dims,
        transform=self.transform,
        centered=self.centered,
        **new_parameters,
    )
def __eq__(self, other) -> bool:
    """Check if two priors are equal.

    Equality requires matching parameters (compared with
    ``np.testing.assert_equal`` so array-valued parameters are handled)
    plus matching distribution name, dims, centering, and transform.
    """
    if not isinstance(other, Prior):
        return False

    # Parameters may contain numpy arrays or nested Priors; assert_equal
    # handles both, raising AssertionError on any mismatch.
    try:
        np.testing.assert_equal(self.parameters, other.parameters)
    except AssertionError:
        return False

    if self.distribution != other.distribution:
        return False
    if self.dims != other.dims:
        return False
    if self.centered != other.centered:
        return False
    return self.transform == other.transform
def sample_prior(
    self,
    coords=None,
    name: str = "var",
    **sample_prior_predictive_kwargs,
) -> xr.Dataset:
    """Draw prior-predictive samples for this variable.

    Parameters
    ----------
    coords : dict[str, list[str]], optional
        The coordinates for the variable, by default None.
        Only required if the dims are specified.
    name : str, optional
        The name of the variable, by default "var".
    sample_prior_predictive_kwargs : dict
        Additional arguments to pass to `pm.sample_prior_predictive`.

    Returns
    -------
    xr.Dataset
        The dataset of the prior samples.

    Example
    -------
    Sample from a hierarchical normal distribution.

    .. code-block:: python

        dist = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )

        coords = {"channel": ["C1", "C2", "C3"]}
        prior = dist.sample_prior(coords=coords)

    """
    # Delegate to the module-level helper, which builds a throwaway model
    # around this factory and samples it.
    return sample_prior(
        factory=self, coords=coords, name=name, **sample_prior_predictive_kwargs
    )
def __deepcopy__(self, memo) -> Prior:
    """Return a deep copy of the prior, honoring the deepcopy memo."""
    key = id(self)
    # Respect the memo so shared/recursive references copy correctly.
    if key in memo:
        return memo[key]

    duplicate = Prior(
        self.distribution,
        dims=copy.copy(self.dims),
        centered=self.centered,
        transform=self.transform,
        **copy.deepcopy(self.parameters),
    )
    memo[key] = duplicate
    return duplicate
def deepcopy(self) -> Prior:
    """Return an independent deep copy of the prior."""
    duplicate = copy.deepcopy(self)
    return duplicate
def to_graph(self):
    """Generate a graphviz graph of the variables.

    Examples
    --------
    Create the graph for a 2D transformed hierarchical distribution.

    .. code-block:: python

        from pymc_extras.prior import Prior

        mu = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )
        sigma = Prior("HalfNormal", dims="channel")
        dist = Prior(
            "Normal",
            mu=mu,
            sigma=sigma,
            dims=("channel", "geo"),
            centered=False,
            transform="sigmoid",
        )

        dist.to_graph()

    .. image:: /_static/example-graph.png
        :alt: Example graph

    """
    # Single placeholder coordinates are enough to materialize the model
    # structure for plotting.
    dummy_coords = {dim: ["DUMMY"] for dim in self.dims}
    with pm.Model(coords=dummy_coords) as model:
        self.create_variable("var")

    return pm.model_to_graphviz(model)
def create_likelihood_variable(
    self,
    name: str,
    mu: pt.TensorLike,
    observed: pt.TensorLike,
) -> pt.TensorVariable:
    """Create a likelihood variable from the prior.

    Will require that the distribution has a `mu` parameter
    and that it has not been set in the parameters.

    Parameters
    ----------
    name : str
        The name of the variable.
    mu : pt.TensorLike
        The mu parameter for the likelihood.
    observed : pt.TensorLike
        The observed data.

    Returns
    -------
    pt.TensorVariable
        The PyMC variable.

    Examples
    --------
    Create a likelihood variable in a larger PyMC model.

    .. code-block:: python

        import pymc as pm

        dist = Prior("Normal", sigma=Prior("HalfNormal"))

        with pm.Model():
            # Create the likelihood variable
            mu = pm.Normal("mu", mu=0, sigma=1)
            dist.create_likelihood_variable("y", mu=mu, observed=observed)

    """
    supported = _get_pymc_parameters(self.pymc_distribution)
    if "mu" not in supported:
        raise UnsupportedDistributionError(
            f"Likelihood distribution {self.distribution!r} is not supported."
        )

    if "mu" in self.parameters:
        raise MuAlreadyExistsError(self)

    # Copy so the original Prior is untouched, then inject the likelihood
    # mean and the observations before building the variable.
    likelihood = self.deepcopy()
    likelihood.parameters["mu"] = mu
    likelihood.parameters["observed"] = observed
    return likelihood.create_variable(name)
# Raised by _remove_random_variable when the variable to detach is not
# among the model's free random variables.
class VariableNotFound(Exception):
    """Variable is not found."""
def _remove_random_variable(var: pt.TensorVariable) -> None:
    """Detach a named free random variable from the current model context.

    Used so a variable can be rebuilt under the same name (e.g. wrapped in
    ``pm.Censored``) without a duplicate-name clash.
    """
    if var.name is None:
        raise ValueError("This isn't removable")

    name: str = var.name

    model = pm.modelcontext(None)
    position = None
    for idx, free_rv in enumerate(model.free_RVs):
        if var == free_rv:
            position = idx
            break

    if position is None:
        raise VariableNotFound(f"Variable {var.name!r} not found")

    # Clear the name first so the variable no longer claims its slot in the
    # model's registries, then drop it from both bookkeeping structures.
    var.name = None
    model.free_RVs.pop(position)
    model.named_vars.pop(name)
@dataclass
class Censored:
    """Create censored random variable.

    Examples
    --------
    Create a censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior("Normal")
        censored_normal = Censored(normal, lower=0)

    Create hierarchical censored Normal distribution:

    .. code-block:: python

        from pymc_extras.prior import Prior, Censored

        normal = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
        )
        censored_normal = Censored(normal, lower=0)

        coords = {"channel": range(3)}
        samples = censored_normal.sample_prior(coords=coords)

    """

    # The base prior to censor; must be centered and untransformed (see
    # __post_init__) so that pm.Censored can consume its .dist() form.
    distribution: InstanceOf[Prior]
    # Censoring bounds; defaults leave the corresponding side unbounded.
    lower: float | InstanceOf[pt.TensorVariable] = -np.inf
    upper: float | InstanceOf[pt.TensorVariable] = np.inf

    def __post_init__(self) -> None:
        """Check validity at initialization."""
        if not self.distribution.centered:
            raise ValueError(
                "Censored distribution must be centered so that .dist() API can be used on distribution."
            )

        if self.distribution.transform is not None:
            raise ValueError(
                "Censored distribution can't have a transform so that .dist() API can be used on distribution."
            )

    @property
    def dims(self) -> tuple[str, ...]:
        """The dims from the distribution to censor."""
        return self.distribution.dims

    @dims.setter
    def dims(self, dims) -> None:
        # Dims live on the wrapped distribution; forward assignments there.
        self.distribution.dims = dims

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create censored random variable."""
        # Build the base variable (this registers it on the model), then
        # detach it so only the censored wrapper below owns the name.
        dist = self.distribution.create_variable(name)
        _remove_random_variable(var=dist)

        return pm.Censored(
            name,
            dist,
            lower=self.lower,
            upper=self.upper,
            dims=self.dims,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert the censored distribution to a dictionary."""

        def handle_value(value):
            # Symbolic bounds are evaluated into plain lists so the
            # resulting dict stays JSON-serializable.
            if isinstance(value, pt.TensorVariable):
                return value.eval().tolist()

            return value

        return {
            "class": "Censored",
            "data": {
                "dist": self.distribution.to_dict(),
                "lower": handle_value(self.lower),
                "upper": handle_value(self.upper),
            },
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Censored:
        """Create a censored distribution from a dictionary."""
        # Expects the {"class": "Censored", "data": {...}} envelope
        # produced by to_dict.
        data = data["data"]
        return cls(  # type: ignore
            distribution=Prior.from_dict(data["dist"]),
            lower=data["lower"],
            upper=data["upper"],
        )

    def sample_prior(
        self,
        coords=None,
        name: str = "variable",
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample the prior distribution for the variable.

        Parameters
        ----------
        coords : dict[str, list[str]], optional
            The coordinates for the variable, by default None.
            Only required if the dims are specified.
        name : str, optional
            The name of the variable, by default "var".
        sample_prior_predictive_kwargs : dict
            Additional arguments to pass to `pm.sample_prior_predictive`.

        Returns
        -------
        xr.Dataset
            The dataset of the prior samples.

        Example
        -------
        Sample from a censored Gamma distribution.

        .. code-block:: python

            gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
            dist = Censored(gamma, lower=0.5)

            coords = {"channel": ["C1", "C2", "C3"]}
            prior = dist.sample_prior(coords=coords)

        """
        # Delegate to the module-level helper shared with Prior.
        return sample_prior(
            factory=self,
            coords=coords,
            name=name,
            **sample_prior_predictive_kwargs,
        )

    def to_graph(self):
        """Generate a graph of the variables.

        Examples
        --------
        Create graph for a censored Normal distribution

        .. code-block:: python

            from pymc_extras.prior import Prior, Censored

            normal = Prior("Normal")
            censored_normal = Censored(normal, lower=0)

            censored_normal.to_graph()

        """
        # Placeholder coordinates are enough to materialize the model graph.
        coords = {name: ["DUMMY"] for name in self.dims}
        with pm.Model(coords=coords) as model:
            self.create_variable("var")

        return pm.model_to_graphviz(model)

    def create_likelihood_variable(
        self,
        name: str,
        mu: pt.TensorLike,
        observed: pt.TensorLike,
    ) -> pt.TensorVariable:
        """Create observed censored variable.

        Will require that the distribution has a `mu` parameter
        and that it has not been set in the parameters.

        Parameters
        ----------
        name : str
            The name of the variable.
        mu : pt.TensorLike
            The mu parameter for the likelihood.
        observed : pt.TensorLike
            The observed data.

        Returns
        -------
        pt.TensorVariable
            The PyMC variable.

        Examples
        --------
        Create a censored likelihood variable in a larger PyMC model.

        .. code-block:: python

            import pymc as pm
            from pymc_extras.prior import Prior, Censored

            normal = Prior("Normal", sigma=Prior("HalfNormal"))
            dist = Censored(normal, lower=0)

            observed = 1

            with pm.Model():
                # Create the likelihood variable
                mu = pm.HalfNormal("mu", sigma=1)
                dist.create_likelihood_variable("y", mu=mu, observed=observed)

        """
        if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
            raise UnsupportedDistributionError(
                f"Likelihood distribution {self.distribution.distribution!r} is not supported."
            )

        if "mu" in self.distribution.parameters:
            raise MuAlreadyExistsError(self.distribution)

        # Copy the wrapped prior so the original stays untouched, then
        # inject the likelihood mean.
        distribution = self.distribution.deepcopy()
        distribution.parameters["mu"] = mu

        # Same create-then-detach trick as create_variable: the base
        # variable must exist to be wrapped, but only the censored
        # likelihood should keep the name.
        dist = distribution.create_variable(name)
        _remove_random_variable(var=dist)

        return pm.Censored(
            name,
            dist,
            observed=observed,
            lower=self.lower,
            upper=self.upper,
            dims=self.dims,
        )
class Scaled:
    """Wrap a prior so its draws are multiplied by a constant factor.

    Intended for numerical stability: sample on a well-conditioned scale
    and rescale deterministically.
    """

    def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
        self.dist = dist
        self.factor = factor

    @property
    def dims(self) -> Dims:
        """The dimensions of the scaled distribution (taken from the wrapped prior)."""
        return self.dist.dims

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a scaled variable.

        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        pt.TensorVariable
            The scaled variable.
        """
        # The raw draw keeps a derived name; the public name is attached
        # to the rescaled deterministic.
        unscaled = self.dist.create_variable(f"{name}_unscaled")
        scaled = unscaled * self.factor
        return pm.Deterministic(name, scaled, dims=self.dims)
def _is_prior_type(data: dict) -> bool:
|
|
1348
|
+
return "dist" in data
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
def _is_censored_type(data: dict) -> bool:
|
|
1352
|
+
return data.keys() == {"class", "data"} and data["class"] == "Censored"
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
# Hook Prior and Censored into the shared deserialization registry so that
# `deserialize` can round-trip the dictionary formats produced by their
# respective `to_dict` methods.
register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)